From 747b5748ec1306cbd6fb01ae58e2cf91678291b0 Mon Sep 17 00:00:00 2001 From: xlycae Date: Fri, 29 May 2026 11:55:23 +0800 Subject: [PATCH 1/6] feat: add native helios distilled support --- app/utils/model_utils.py | 68 ++ configs/helios/helios_distilled_i2v.json | 21 + configs/helios/helios_distilled_t2v.json | 19 + lightx2v/infer.py | 2 + .../input_encoders/hf/helios/__init__.py | 3 + .../models/input_encoders/hf/helios/model.py | 78 ++ lightx2v/models/networks/helios/__init__.py | 3 + lightx2v/models/networks/helios/model.py | 96 ++ .../networks/helios/transformer_helios.py | 819 ++++++++++++++++++ lightx2v/models/runners/helios/__init__.py | 1 + .../models/runners/helios/helios_runner.py | 406 +++++++++ .../models/runners/helios/runtime_utils.py | 43 + lightx2v/models/schedulers/helios/__init__.py | 3 + .../models/schedulers/helios/helios_dmd.py | 331 +++++++ .../models/schedulers/helios/scheduler.py | 47 + .../video_encoders/hf/helios/__init__.py | 3 + .../models/video_encoders/hf/helios/vae.py | 67 ++ lightx2v/pipeline.py | 12 +- lightx2v/utils/set_config.py | 32 + lightx2v/utils/utils.py | 8 + scripts/helios/run_helios_distilled_i2v.sh | 19 + scripts/helios/run_helios_distilled_t2v.sh | 17 + test_cases/test_helios_consistency_helpers.py | 195 +++++ test_cases/test_helios_distilled_support.py | 198 +++++ 24 files changed, 2490 insertions(+), 1 deletion(-) create mode 100644 configs/helios/helios_distilled_i2v.json create mode 100644 configs/helios/helios_distilled_t2v.json create mode 100644 lightx2v/models/input_encoders/hf/helios/__init__.py create mode 100644 lightx2v/models/input_encoders/hf/helios/model.py create mode 100644 lightx2v/models/networks/helios/__init__.py create mode 100644 lightx2v/models/networks/helios/model.py create mode 100644 lightx2v/models/networks/helios/transformer_helios.py create mode 100644 lightx2v/models/runners/helios/__init__.py create mode 100644 lightx2v/models/runners/helios/helios_runner.py create mode 100644 
lightx2v/models/runners/helios/runtime_utils.py create mode 100644 lightx2v/models/schedulers/helios/__init__.py create mode 100644 lightx2v/models/schedulers/helios/helios_dmd.py create mode 100644 lightx2v/models/schedulers/helios/scheduler.py create mode 100644 lightx2v/models/video_encoders/hf/helios/__init__.py create mode 100644 lightx2v/models/video_encoders/hf/helios/vae.py create mode 100644 scripts/helios/run_helios_distilled_i2v.sh create mode 100644 scripts/helios/run_helios_distilled_t2v.sh create mode 100644 test_cases/test_helios_consistency_helpers.py create mode 100644 test_cases/test_helios_distilled_support.py diff --git a/app/utils/model_utils.py b/app/utils/model_utils.py index c40181d86..f3b8055e5 100644 --- a/app/utils/model_utils.py +++ b/app/utils/model_utils.py @@ -22,6 +22,7 @@ MS_AVAILABLE = False import gc import importlib.util +import json import re import psutil @@ -471,6 +472,68 @@ def get_quant_scheme(quant_detected, quant_op_val): return f"{quant_detected}-{quant_op_val}" +def _load_json_if_exists(path): + if path and os.path.exists(path): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + return None + + +def detect_helios_variant(model_path_input): + model_index = _load_json_if_exists(os.path.join(model_path_input, "model_index.json")) or {} + modular_model_index = _load_json_if_exists(os.path.join(model_path_input, "modular_model_index.json")) or {} + + scheduler_name = "" + scheduler_entry = model_index.get("scheduler") + if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2: + scheduler_name = scheduler_entry[1] + + if not scheduler_name: + scheduler_entry = modular_model_index.get("scheduler") + if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2: + scheduler_name = scheduler_entry[1] + + is_distilled = bool(model_index.get("is_distilled")) or "Distilled" in (modular_model_index.get("_class_name") or "") or scheduler_name == "HeliosDMDScheduler" + variant = "distilled" if 
is_distilled else "base" + return variant, scheduler_name or ("HeliosDMDScheduler" if is_distilled else "HeliosScheduler"), model_index, modular_model_index + + +def build_helios(model_path_input): + variant, scheduler_type, model_index, modular_model_index = detect_helios_variant(model_path_input) + transformer_config = _load_json_if_exists(os.path.join(model_path_input, "transformer", "config.json")) or {} + scheduler_config = _load_json_if_exists(os.path.join(model_path_input, "scheduler", "scheduler_config.json")) or {} + + helios_config = { + "model_cls": "helios", + "model_variant": variant, + "is_distilled": variant == "distilled", + "model_path": model_path_input, + "transformer_model_path": os.path.join(model_path_input, "transformer"), + "text_encoder_path": os.path.join(model_path_input, "text_encoder"), + "tokenizer_path": os.path.join(model_path_input, "tokenizer"), + "vae_path": os.path.join(model_path_input, "vae"), + "scheduler_path": os.path.join(model_path_input, "scheduler"), + "scheduler_type": scheduler_type, + "model_index_class": model_index.get("_class_name") or modular_model_index.get("_class_name"), + "guider_config_path": os.path.join(model_path_input, "guider", "guider_config.json"), + "transformer_ode_model_path": os.path.join(model_path_input, "transformer_ode"), + "history_sizes": [16, 2, 1], + "num_latent_frames_per_chunk": 9, + "use_zero_init": False, + "zero_steps": 1, + "is_enable_stage2": False, + "pyramid_num_inference_steps_list": [20, 20, 20], + "is_skip_first_chunk": False, + "is_amplify_first_chunk": False, + "image_noise_sigma_min": 0.111, + "image_noise_sigma_max": 0.135, + "use_dynamic_shifting": scheduler_config.get("use_dynamic_shifting"), + "use_flow_sigmas": scheduler_config.get("use_flow_sigmas"), + } + helios_config.update(transformer_config) + return helios_config + + def build_wan21( model_path_input, dit_path_input, @@ -945,3 +1008,8 @@ def get_model_configs( if lora_configs: config["lora_configs"] = lora_configs 
return config + elif model_type_input == "Helios": + config = build_helios(model_path_input) + if lora_configs: + config["lora_configs"] = lora_configs + return config diff --git a/configs/helios/helios_distilled_i2v.json b/configs/helios/helios_distilled_i2v.json new file mode 100644 index 000000000..bfa00b210 --- /dev/null +++ b/configs/helios/helios_distilled_i2v.json @@ -0,0 +1,21 @@ +{ + "model_cls": "helios", + "model_variant": "distilled", + "infer_steps": 6, + "target_video_length": 99, + "target_height": 384, + "target_width": 640, + "sample_guide_scale": 1.0, + "enable_cfg": false, + "fps": 24, + "history_sizes": [16, 2, 1], + "num_latent_frames_per_chunk": 9, + "use_zero_init": false, + "zero_steps": 1, + "is_enable_stage2": false, + "pyramid_num_inference_steps_list": [2, 2, 2], + "is_skip_first_chunk": false, + "is_amplify_first_chunk": false, + "image_noise_sigma_min": 0.111, + "image_noise_sigma_max": 0.135 +} diff --git a/configs/helios/helios_distilled_t2v.json b/configs/helios/helios_distilled_t2v.json new file mode 100644 index 000000000..b10fec7b3 --- /dev/null +++ b/configs/helios/helios_distilled_t2v.json @@ -0,0 +1,19 @@ +{ + "model_cls": "helios", + "model_variant": "distilled", + "infer_steps": 6, + "target_video_length": 99, + "target_height": 384, + "target_width": 640, + "sample_guide_scale": 1.0, + "enable_cfg": false, + "fps": 24, + "history_sizes": [16, 2, 1], + "num_latent_frames_per_chunk": 9, + "use_zero_init": false, + "zero_steps": 1, + "is_enable_stage2": false, + "pyramid_num_inference_steps_list": [2, 2, 2], + "is_skip_first_chunk": false, + "is_amplify_first_chunk": false +} diff --git a/lightx2v/infer.py b/lightx2v/infer.py index ac4dc2aa5..71adc40e0 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -12,6 +12,7 @@ # from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import 
HunyuanVideo15DistillRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 +from lightx2v.models.runners.helios.helios_runner import HeliosRunner # noqa: F401 from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner # noqa: F401 from lightx2v.models.runners.motus.motus_runner import MotusRunner # noqa: F401 @@ -76,6 +77,7 @@ def main(): "wan2.2_animate", "hunyuan_video_1.5", "hunyuan_video_1.5_distill", + "helios", "hunyuan3d", "worldplay_distill", "worldplay_ar", diff --git a/lightx2v/models/input_encoders/hf/helios/__init__.py b/lightx2v/models/input_encoders/hf/helios/__init__.py new file mode 100644 index 000000000..3b42dd490 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/helios/__init__.py @@ -0,0 +1,3 @@ +from lightx2v.models.input_encoders.hf.helios.model import HeliosTextEncoder + +__all__ = ["HeliosTextEncoder"] diff --git a/lightx2v/models/input_encoders/hf/helios/model.py b/lightx2v/models/input_encoders/hf/helios/model.py new file mode 100644 index 000000000..440fc88a6 --- /dev/null +++ b/lightx2v/models/input_encoders/hf/helios/model.py @@ -0,0 +1,78 @@ +import html + +import regex as re +import torch +from transformers import AutoTokenizer, UMT5EncoderModel + +from lightx2v.utils.envs import GET_DTYPE +from lightx2v_platform.base.global_var import AI_DEVICE + +try: + import ftfy +except ImportError: + ftfy = None + + +def basic_clean(text): + if ftfy is not None: + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + return re.sub(r"\s+", " ", text).strip() + + +def prompt_clean(text): + return whitespace_clean(basic_clean(text)) + + +def pack_t5_prompt_embeds(hidden_state, attention_mask, max_sequence_length, num_videos_per_prompt=1, dtype=None, device=None): + device = device or 
hidden_state.device + dtype = dtype or hidden_state.dtype + prompt_embeds = hidden_state.to(dtype=dtype, device=device) + attention_mask = attention_mask.to(device=device) + seq_lens = attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], + dim=0, + ) + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(len(seq_lens) * num_videos_per_prompt, seq_len, -1) + return prompt_embeds, attention_mask.bool() + + +class HeliosTextEncoder: + def __init__(self, config): + self.config = config + self.device = torch.device("cpu") if config.get("t5_cpu_offload", config.get("cpu_offload", False)) else torch.device(AI_DEVICE) + self.dtype = GET_DTYPE() + self.tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"]) + self.text_encoder = UMT5EncoderModel.from_pretrained(config["text_encoder_path"], torch_dtype=self.dtype).to(self.device) + + def infer(self, prompts, max_sequence_length=None): + max_sequence_length = max_sequence_length or self.config.get("max_sequence_length", 512) + prompts = [prompt_clean(prompt) for prompt in prompts] + text_inputs = self.tokenizer( + prompts, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + input_ids = text_inputs.input_ids.to(self.device) + attention_mask = text_inputs.attention_mask.to(self.device) + hidden_state = self.text_encoder(input_ids, attention_mask).last_hidden_state + return pack_t5_prompt_embeds( + hidden_state, + attention_mask, + max_sequence_length=max_sequence_length, + num_videos_per_prompt=1, + dtype=self.dtype, + device=torch.device(AI_DEVICE), + ) diff --git a/lightx2v/models/networks/helios/__init__.py 
b/lightx2v/models/networks/helios/__init__.py new file mode 100644 index 000000000..43d90dc59 --- /dev/null +++ b/lightx2v/models/networks/helios/__init__.py @@ -0,0 +1,3 @@ +from lightx2v.models.networks.helios.model import HeliosModel + +__all__ = ["HeliosModel"] diff --git a/lightx2v/models/networks/helios/model.py b/lightx2v/models/networks/helios/model.py new file mode 100644 index 000000000..ad62b6f44 --- /dev/null +++ b/lightx2v/models/networks/helios/model.py @@ -0,0 +1,96 @@ +import os + +import torch +from loguru import logger + +from lightx2v.models.networks.helios.transformer_helios import HeliosTransformer3DModel +from lightx2v.utils.envs import GET_DTYPE +from lightx2v_platform.base.global_var import AI_DEVICE + + +class HeliosModel: + def __init__(self, model_path, config, device): + self.config = config + self.device = device + transformer_path = config.get("transformer_model_path") or model_path + self.transformer = HeliosTransformer3DModel.from_pretrained( + transformer_path, + subfolder=None if os.path.basename(transformer_path) == "transformer" else "transformer", + torch_dtype=GET_DTYPE(), + ).to(device) + self.scheduler = None + self._set_attention_backend() + + def _set_attention_backend(self): + attn_type = self.config.get("attn_type") + if not attn_type: + return + try: + if attn_type == "flash_attn3": + self.transformer.set_attention_backend("_flash_3_hub") + elif attn_type == "flash_attn2": + self.transformer.set_attention_backend("flash_hub") + elif attn_type == "torch_sdpa": + self.transformer.set_attention_backend("sdpa") + except Exception as exc: + logger.warning(f"Failed to set Helios attention backend {attn_type}: {exc}") + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @property + def dtype(self): + return self.transformer.dtype + + def infer_noise( + self, + latents, + timestep, + encoder_hidden_states, + history_inputs, + attention_kwargs=None, + ): + return self.transformer( + hidden_states=latents, + 
timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + attention_kwargs=attention_kwargs, + return_dict=False, + **history_inputs, + )[0] + + def infer_cfg( + self, + latents, + timestep, + prompt_embeds, + negative_prompt_embeds, + history_inputs, + guidance_scale, + attention_kwargs=None, + is_cfg_zero_star=False, + use_zero_init=False, + zero_steps=1, + stage_idx=0, + step_idx=0, + ): + with self.transformer.cache_context("cond"): + noise_pred = self.infer_noise(latents, timestep, prompt_embeds, history_inputs, attention_kwargs) + + if guidance_scale <= 1.0 or negative_prompt_embeds is None: + return noise_pred + + with self.transformer.cache_context("uncond"): + noise_uncond = self.infer_noise(latents, timestep, negative_prompt_embeds, history_inputs, attention_kwargs) + + if is_cfg_zero_star: + positive_flat = noise_pred.view(noise_pred.shape[0], -1).float() + negative_flat = noise_uncond.view(noise_uncond.shape[0], -1).float() + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8 + alpha = (dot_product / squared_norm).view(noise_pred.shape[0], *([1] * (noise_pred.ndim - 1))).to(noise_pred.dtype) + if stage_idx == 0 and step_idx <= zero_steps and use_zero_init: + return noise_pred * 0.0 + return noise_uncond * alpha + guidance_scale * (noise_pred - noise_uncond * alpha) + + return noise_uncond + guidance_scale * (noise_pred - noise_uncond) diff --git a/lightx2v/models/networks/helios/transformer_helios.py b/lightx2v/models/networks/helios/transformer_helios.py new file mode 100644 index 000000000..f311be0af --- /dev/null +++ b/lightx2v/models/networks/helios/transformer_helios.py @@ -0,0 +1,819 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin +from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput +from diffusers.models.attention import AttentionMixin, AttentionModuleMixin, FeedForward +from diffusers.models.attention_dispatch import dispatch_attention_fn +from diffusers.models.cache_utils import CacheMixin +from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import FP32LayerNorm +from diffusers.utils import apply_lora_scale, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def pad_for_3d_conv(x, kernel_size): + b, c, t, h, w = x.shape + pt, ph, pw = kernel_size + pad_t = (pt - (t % pt)) % pt + pad_h = (ph - (h % ph)) % ph + pad_w = (pw - (w % pw)) % pw + return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate") + + +def center_down_sample_3d(x, kernel_size): + return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size) + + +def apply_rotary_emb_transposed( + hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, +): + x_1, x_2 = hidden_states.unflatten(-1, 
(-1, 2)).unbind(-1) + cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1) + out = torch.empty_like(hidden_states) + out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2] + out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2] + return out.type_as(hidden_states) + + +def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor): + # encoder_hidden_states is only passed for cross-attention + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + if attn.fused_projections: + if not attn.is_cross_attention: + # In self-attention layers, we can fuse the entire QKV projection into a single linear + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + else: + # In cross-attention layers, we can only fuse the KV projections into a single linear + query = attn.to_q(hidden_states) + key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1) + else: + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + return query, key, value + + +class HeliosOutputNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False): + super().__init__() + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False) + + def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int): + temb = temb[:, -original_context_length:, :] + shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2) + shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device) + hidden_states = hidden_states[:, -original_context_length:, :] + hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + return hidden_states + + +class HeliosAttnProcessor: + _attention_backend = None + 
_parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." + ) + + def __call__( + self, + attn: "HeliosAttention", + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + original_context_length: int = None, + ) -> torch.Tensor: + query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)) + key = key.unflatten(2, (attn.heads, -1)) + value = value.unflatten(2, (attn.heads, -1)) + + if rotary_emb is not None: + query = apply_rotary_emb_transposed(query, rotary_emb) + key = apply_rotary_emb_transposed(key, rotary_emb) + + if not attn.is_cross_attention and attn.is_amplify_history: + history_seq_len = hidden_states.shape[1] - original_context_length + + if history_seq_len > 0: + scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0) + if attn.history_scale_mode == "per_head": + scale_key = scale_key.view(1, 1, -1, 1) + key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1) + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), + ) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class HeliosAttention(torch.nn.Module, AttentionModuleMixin): + 
_default_processor_cls = HeliosAttnProcessor + _available_processors = [HeliosAttnProcessor] + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: int | None = None, + cross_attention_dim_head: int | None = None, + processor=None, + is_cross_attention=None, + is_amplify_history=False, + history_scale_mode="per_head", # [scalar, per_head] + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = torch.nn.ModuleList( + [ + torch.nn.Linear(self.inner_dim, dim, bias=True), + torch.nn.Dropout(dropout), + ] + ) + self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps) + + if is_cross_attention is not None: + self.is_cross_attention = is_cross_attention + else: + self.is_cross_attention = cross_attention_dim_head is not None + + self.set_processor(processor) + + self.is_amplify_history = is_amplify_history + if is_amplify_history: + if history_scale_mode == "scalar": + self.history_key_scale = nn.Parameter(torch.ones(1)) + elif history_scale_mode == "per_head": + self.history_key_scale = nn.Parameter(torch.ones(heads)) + 
else: + raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}") + self.history_scale_mode = history_scale_mode + self.max_scale = 10.0 + + def fuse_projections(self): + if getattr(self, "fused_projections", False): + return + + if not self.is_cross_attention: + concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_qkv = nn.Linear(in_features, out_features, bias=True) + self.to_qkv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + else: + concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) + concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_kv = nn.Linear(in_features, out_features, bias=True) + self.to_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + if self.added_kv_proj_dim is not None: + concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) + concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data]) + out_features, in_features = concatenated_weights.shape + with torch.device("meta"): + self.to_added_kv = nn.Linear(in_features, out_features, bias=True) + self.to_added_kv.load_state_dict( + {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True + ) + + self.fused_projections = True + + @torch.no_grad() + def unfuse_projections(self): + if not getattr(self, "fused_projections", False): + return + + if hasattr(self, "to_qkv"): + delattr(self, "to_qkv") + if hasattr(self, "to_kv"): + delattr(self, "to_kv") + if hasattr(self, "to_added_kv"): + 
delattr(self, "to_added_kv") + + self.fused_projections = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + original_context_length: int = None, + **kwargs, + ) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states, + attention_mask, + rotary_emb, + original_context_length, + **kwargs, + ) + + +class HeliosTimeTextEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + is_return_encoder_hidden_states: bool = True, + ): + timestep = self.timesteps_proj(timestep) + + time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype + if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8: + timestep = timestep.to(time_embedder_dtype) + temb = self.time_embedder(timestep).type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)) + + if encoder_hidden_states is not None and is_return_encoder_hidden_states: + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + + return temb, timestep_proj, encoder_hidden_states + + +class HeliosRotaryPosEmbed(nn.Module): + def __init__(self, rope_dim, theta): + super().__init__() + self.DT, self.DY, self.DX = rope_dim + self.theta = theta + self.register_buffer("freqs_base_t", 
self._get_freqs_base(self.DT), persistent=False) + self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False) + self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False) + + def _get_freqs_base(self, dim): + return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim)) + + @torch.no_grad() + def get_frequency_batched(self, freqs_base, pos): + freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos) + freqs = freqs.repeat_interleave(2, dim=0) + return freqs.cos(), freqs.sin() + + @torch.no_grad() + def _get_spatial_meshgrid(self, height, width, device_str): + device = torch.device(device_str) + grid_y_coords = torch.arange(height, device=device, dtype=torch.float32) + grid_x_coords = torch.arange(width, device=device, dtype=torch.float32) + grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij") + return grid_y, grid_x + + @torch.no_grad() + def forward(self, frame_indices, height, width, device): + batch_size = frame_indices.shape[0] + num_frames = frame_indices.shape[1] + + frame_indices = frame_indices.to(device=device, dtype=torch.float32) + grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device)) + + grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width) + grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1) + grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1) + + freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t) + freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch) + freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch) + + result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0) + + return result.permute(1, 0, 2, 3, 4) + + +@maybe_allow_in_graph +class HeliosTransformerBlock(nn.Module): + def __init__( + 
self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + guidance_cross_attn: bool = False, + is_amplify_history: bool = False, + history_scale_mode: str = "per_head", # [scalar, per_head] + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = HeliosAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + processor=HeliosAttnProcessor(), + is_amplify_history=is_amplify_history, + history_scale_mode=history_scale_mode, + ) + + # 2. Cross-attention + self.attn2 = HeliosAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + processor=HeliosAttnProcessor(), + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + # 4. 
Guidance cross-attention + self.guidance_cross_attn = guidance_cross_attn + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + original_context_length: int = None, + ) -> torch.Tensor: + if temb.ndim == 4: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1( + norm_hidden_states, + None, + None, + rotary_emb, + original_context_length, + ) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. 
Cross-attention + if self.guidance_cross_attn: + history_seq_len = hidden_states.shape[1] - original_context_length + + history_hidden_states, hidden_states = torch.split( + hidden_states, [history_seq_len, original_context_length], dim=1 + ) + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states, + None, + None, + original_context_length, + ) + hidden_states = hidden_states + attn_output + hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1) + else: + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states, + None, + None, + original_context_length, + ) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + + +class HeliosTransformer3DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin +): + r""" + A Transformer model for video-like data used in the Helios model. + + Args: + patch_size (`tuple[int]`, defaults to `(1, 2, 2)`): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch). + num_attention_heads (`int`, defaults to `40`): + Fixed length for text embeddings. + attention_head_dim (`int`, defaults to `128`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, defaults to `16`): + The number of channels in the output. + text_dim (`int`, defaults to `512`): + Input dimension for text embeddings. + freq_dim (`int`, defaults to `256`): + Dimension for sinusoidal time embeddings. 
+ ffn_dim (`int`, defaults to `13824`): + Intermediate dimension in feed-forward network. + num_layers (`int`, defaults to `40`): + The number of layers of transformer blocks to use. + window_size (`tuple[int]`, defaults to `(-1, -1)`): + Window size for local attention (-1 indicates global attention). + cross_attn_norm (`bool`, defaults to `True`): + Enable cross-attention normalization. + qk_norm (`bool`, defaults to `True`): + Enable query/key normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + add_img_emb (`bool`, defaults to `False`): + Whether to use img_emb. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + """ + + _supports_gradient_checkpointing = True + _skip_layerwise_casting_patterns = [ + "patch_embedding", + "patch_short", + "patch_mid", + "patch_long", + "condition_embedder", + "norm", + ] + _no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"] + _keep_in_fp32_modules = [ + "time_embedder", + "scale_shift_table", + "norm1", + "norm2", + "norm3", + "history_key_scale", + ] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + _repeated_blocks = ["HeliosTransformerBlock"] + _cp_plan = { + # Input split at attn level and ffn level. + "blocks.*.attn1": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*.attn2": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + "blocks.*.ffn": { + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + }, + # Output gather at attn level and ffn level. 
+ **{f"blocks.{i}.attn1": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.attn2": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + **{f"blocks.{i}.ffn": ContextParallelOutput(gather_dim=1, expected_dims=3) for i in range(40)}, + } + + @register_to_config + def __init__( + self, + patch_size: tuple[int, ...] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: str | None = "rms_norm_across_heads", + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + rope_dim: tuple[int, ...] = (44, 42, 42), + rope_theta: float = 10000.0, + guidance_cross_attn: bool = True, + zero_history_timestep: bool = True, + has_multi_term_memory_patch: bool = True, + is_amplify_history: bool = False, + history_scale_mode: str = "per_head", # [scalar, per_head] + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + # 1. Patch & position embedding + self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Initial Multi Term Memory Patch + self.zero_history_timestep = zero_history_timestep + if has_multi_term_memory_patch: + self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.patch_mid = nn.Conv3d( + in_channels, + inner_dim, + kernel_size=tuple(2 * p for p in patch_size), + stride=tuple(2 * p for p in patch_size), + ) + self.patch_long = nn.Conv3d( + in_channels, + inner_dim, + kernel_size=tuple(4 * p for p in patch_size), + stride=tuple(4 * p for p in patch_size), + ) + + # 3. 
@apply_lora_scale("attention_kwargs")
def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    # ------------ Stage 1 ------------
    indices_hidden_states=None,
    indices_latents_history_short=None,
    indices_latents_history_mid=None,
    indices_latents_history_long=None,
    latents_history_short=None,
    latents_history_mid=None,
    latents_history_long=None,
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
    r"""
    Denoise one chunk of latents, optionally conditioned on short/mid/long history latents.

    The noisy latents are patchified into tokens. Each history level that is supplied together with its
    frame indices is patchified by its own convolution and prepended to the token sequence, so the final
    order along dim 1 is `[long, mid, short, current]`. The output only covers the current tokens.

    Args:
        hidden_states (`torch.Tensor`):
            Noisy latents fed to `patch_embedding` (a 3D convolution).
        timestep (`torch.LongTensor`):
            Timestep passed to the condition embedder.
        encoder_hidden_states (`torch.Tensor`):
            Text conditioning passed to the condition embedder and the transformer blocks.
        indices_hidden_states (*optional*):
            Frame indices for the rotary embedding of the current tokens. Defaults to
            `0..post_patch_num_frames - 1` for every batch element.
        indices_latents_history_short / indices_latents_history_mid / indices_latents_history_long (*optional*):
            Frame indices for the rotary embedding of each history level.
        latents_history_short / latents_history_mid / latents_history_long (*optional*):
            History latents. A level is used only when both its latents and its indices are given.
        return_dict (`bool`, defaults to `True`):
            Return a `Transformer2DModelOutput` instead of a one-element tuple.
        attention_kwargs (`dict`, *optional*):
            Not read in this body; presumably consumed by the `apply_lora_scale` decorator.

    Returns:
        `Transformer2DModelOutput` when `return_dict` is true, otherwise the tuple `(output,)`.
    """
    # 1. Input
    batch_size = hidden_states.shape[0]
    p_t, p_h, p_w = self.config.patch_size

    # 2. Process noisy latents
    hidden_states = self.patch_embedding(hidden_states)
    _, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape

    if indices_hidden_states is None:
        indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1)

    hidden_states = hidden_states.flatten(2).transpose(1, 2)
    rotary_emb = self.rope(
        frame_indices=indices_hidden_states,
        height=post_patch_height,
        width=post_patch_width,
        device=hidden_states.device,
    )
    rotary_emb = rotary_emb.flatten(2).transpose(1, 2)
    original_context_length = hidden_states.shape[1]

    # Patch grid (height, width) of the short history level. The mid/long levels build their rotary
    # embedding on this grid before down-sampling it. It stays None when no short history is given, in
    # which case each level derives the grid from its own un-padded latents instead of failing on an
    # unbound name.
    history_grid = None

    # 3. Process short history latents
    if latents_history_short is not None and indices_latents_history_short is not None:
        latents_history_short = self.patch_short(latents_history_short)
        _, _, _, grid_height, grid_width = latents_history_short.shape
        history_grid = (grid_height, grid_width)
        latents_history_short = latents_history_short.flatten(2).transpose(1, 2)

        rotary_emb_history_short = self.rope(
            frame_indices=indices_latents_history_short,
            height=grid_height,
            width=grid_width,
            device=latents_history_short.device,
        )
        rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2)

        hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
        rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)

    # 4. Process mid history latents
    if latents_history_mid is not None and indices_latents_history_mid is not None:
        grid_height, grid_width = history_grid or (
            latents_history_mid.shape[-2] // p_h,
            latents_history_mid.shape[-1] // p_w,
        )
        latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
        latents_history_mid = self.patch_mid(latents_history_mid)
        latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)

        rotary_emb_history_mid = self.rope(
            frame_indices=indices_latents_history_mid,
            height=grid_height,
            width=grid_width,
            device=latents_history_mid.device,
        )
        rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))
        rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2))
        rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)

        hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
        rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)

    # 5. Process long history latents
    if latents_history_long is not None and indices_latents_history_long is not None:
        grid_height, grid_width = history_grid or (
            latents_history_long.shape[-2] // p_h,
            latents_history_long.shape[-1] // p_w,
        )
        latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
        latents_history_long = self.patch_long(latents_history_long)
        latents_history_long = latents_history_long.flatten(2).transpose(1, 2)

        rotary_emb_history_long = self.rope(
            frame_indices=indices_latents_history_long,
            height=grid_height,
            width=grid_width,
            device=latents_history_long.device,
        )
        rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4))
        rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4))
        rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)

        hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
        rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)

    history_context_length = hidden_states.shape[1] - original_context_length

    # `indices_hidden_states` is never None past step 2, so only the config flag decides whether history
    # tokens get a separate timestep-0 embedding.
    use_zero_history_timestep = bool(self.zero_history_timestep)

    if use_zero_history_timestep:
        timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device)
        temb_t0, timestep_proj_t0, _ = self.condition_embedder(
            timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
        )
        temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
        timestep_proj_t0 = (
            timestep_proj_t0.unflatten(-1, (6, -1))
            .view(1, 6, 1, -1)
            .expand(batch_size, -1, history_context_length, -1)
        )

    temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
    timestep_proj = timestep_proj.unflatten(-1, (6, -1))

    # With a separate timestep-0 embedding the real timestep covers only the current tokens; otherwise
    # it is broadcast over the whole sequence, history included.
    main_repeat_size = original_context_length if use_zero_history_timestep else hidden_states.shape[1]
    temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
    timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)

    if use_zero_history_timestep:
        temb = torch.cat([temb_t0, temb], dim=1)
        timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)

    # (batch, 6, seq, dim) -> (batch, seq, 6, dim): the per-token layout the blocks chunk along dim 2.
    timestep_proj = timestep_proj.permute(0, 2, 1, 3)

    # 6. Transformer blocks
    hidden_states = hidden_states.contiguous()
    encoder_hidden_states = encoder_hidden_states.contiguous()
    rotary_emb = rotary_emb.contiguous()
    if torch.is_grad_enabled() and self.gradient_checkpointing:
        for block in self.blocks:
            hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                encoder_hidden_states,
                timestep_proj,
                rotary_emb,
                original_context_length,
            )
    else:
        for block in self.blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states,
                timestep_proj,
                rotary_emb,
                original_context_length,
            )

    # 7. Normalization
    hidden_states = self.norm_out(hidden_states, temb, original_context_length)
    hidden_states = self.proj_out(hidden_states)

    # 8. Unpatchify
    hidden_states = hidden_states.reshape(
        batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
    )
    hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
    output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
def randn_tensor(shape, generator=None, device=None, dtype=None):
    """Sample a standard-normal tensor, optionally with one generator per batch element.

    Args:
        shape: Target shape; the first dimension is the batch dimension.
        generator: ``None``, a single ``torch.Generator``, or a list of generators. A one-element
            list is treated like a single generator. A longer list must hold exactly one generator
            per batch element; each draws its own slice of shape ``(1, *shape[1:])``.
        device: Device forwarded to ``torch.randn``.
        dtype: Dtype forwarded to ``torch.randn``.

    Returns:
        A tensor of the requested shape.

    Raises:
        ValueError: If ``generator`` is an empty list, or a list whose length differs from the
            batch dimension (which would otherwise produce a tensor with the wrong batch size).
    """
    shape = tuple(shape)
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if not isinstance(generator, list):
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)

    if len(generator) != shape[0]:
        raise ValueError(
            f"Got {len(generator)} generators for a batch of {shape[0]}; "
            "pass a single generator or exactly one per batch element."
        )

    per_sample_shape = (1,) + shape[1:]
    samples = [torch.randn(per_sample_shape, generator=g, device=device, dtype=dtype) for g in generator]
    return torch.cat(samples, dim=0)
def get_latent_shape_with_target_hw(self):
    """Return the latent shape ``[channels, frames, height, width]`` for the current request.

    A two-element ``input_info.target_shape`` supplies ``(height, width)``; otherwise the
    ``target_height`` / ``target_width`` config entries are used. Frames and spatial sizes are
    reduced by the matching ``vae_stride`` entries.
    """
    requested_hw = self.input_info.target_shape
    if requested_hw and len(requested_hw) == 2:
        pixel_height = requested_hw[0]
        pixel_width = requested_hw[1]
    else:
        pixel_height = self.config["target_height"]
        pixel_width = self.config["target_width"]

    channels = self.config.get("num_channels_latents", 16)
    latent_frames = (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1
    latent_height = int(pixel_height) // self.config["vae_stride"][1]
    latent_width = int(pixel_width) // self.config["vae_stride"][2]
    return [channels, latent_frames, latent_height, latent_width]
def sample_block_noise(self, batch_size, channel, num_frames, height, width, patch_size, device, generator):
    """Sample noise that is correlated inside each ``ph x pw`` spatial patch.

    Each patch is drawn as one ``ph * pw`` vector with covariance ``(1 + gamma) * I - gamma * 1``
    (``gamma`` comes from the inner scheduler config), then the patches are tiled back into a
    ``(batch_size, channel, num_frames, height, width)`` tensor.
    """
    gamma = self.scheduler.inner.config.gamma
    _, ph, pw = patch_size
    block_size = ph * pw

    identity = torch.eye(block_size, device=device)
    all_ones = torch.ones(block_size, block_size, device=device)
    # NOTE(review): 1e-8 is below float32 resolution for entries near 1, so this jitter may not change
    # the matrix at all -- confirm it has the intended stabilising effect.
    cov = identity * (1 + gamma) - all_ones * gamma
    cov = cov + torch.eye(block_size, device=device) * 1e-8
    chol = torch.linalg.cholesky(cov.float())

    rows = height // ph
    cols = width // pw
    block_number = batch_size * channel * num_frames * rows * cols
    z = torch.randn(block_number, block_size, generator=generator, device=device)

    # Colour the white noise, then interleave (rows, ph) and (cols, pw) back into height and width.
    blocks = (z @ chol.T).view(batch_size, channel, num_frames, rows, cols, ph, pw)
    return blocks.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
pyramid_num_inference_steps_list = self.config.get("pyramid_num_inference_steps_list", [2, 2, 2]) + guidance_scale = self.config.get("sample_guide_scale", 1.0) + use_zero_init = self.config.get("use_zero_init", False) + zero_steps = self.config.get("zero_steps", 1) + is_skip_first_chunk = self.config.get("is_skip_first_chunk", False) + is_amplify_first_chunk = self.config.get("is_amplify_first_chunk", False) + attention_kwargs = None + + window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_encoder.vae_scale_factor_temporal + 1 + num_latent_chunk = max(1, (target_video_length + window_num_frames - 1) // window_num_frames) + num_history_latent_frames = sum(history_sizes) + history_video = None + total_generated_latent_frames = 0 + + history_latents = torch.zeros( + batch_size, + num_channels_latents, + num_history_latent_frames, + height // self.vae_encoder.vae_scale_factor_spatial, + width // self.vae_encoder.vae_scale_factor_spatial, + device=device, + dtype=torch.float32, + ) + if fake_image_latents is not None: + history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents.to(device)], dim=2) + total_generated_latent_frames += 1 + + if self.keep_first_frame: + indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk]), device=device) + ( + indices_prefix, + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_1x, + indices_hidden_states, + ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0) + indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0) + else: + indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk]), device=device) + ( + indices_latents_history_long, + indices_latents_history_mid, + indices_latents_history_short, + indices_hidden_states, + ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0) + indices_hidden_states = indices_hidden_states.unsqueeze(0) + 
indices_latents_history_short = indices_latents_history_short.unsqueeze(0) + indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0) + indices_latents_history_long = indices_latents_history_long.unsqueeze(0) + + for chunk_idx in range(num_latent_chunk): + is_first_chunk = chunk_idx == 0 + is_second_chunk = chunk_idx == 1 + if self.keep_first_frame: + latents_history_long, latents_history_mid, latents_history_1x = history_latents[:, :, -num_history_latent_frames:].split(history_sizes, dim=2) + if image_latents is None and is_first_chunk: + latents_prefix = torch.zeros((batch_size, num_channels_latents, 1, latents_history_1x.shape[-2], latents_history_1x.shape[-1]), device=device, dtype=latents_history_1x.dtype) + else: + latents_prefix = image_latents.to(device) + latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2) + else: + latents_history_long, latents_history_mid, latents_history_short = history_latents[:, :, -num_history_latent_frames:].split(history_sizes, dim=2) + + latents = self._prepare_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + height=height, + width=width, + num_frames=window_num_frames, + generator=self.scheduler.generator, + dtype=torch.float32, + device=device, + ) + num_inference_steps = sum(pyramid_num_inference_steps_list) * 2 if is_amplify_first_chunk and is_first_chunk else sum(pyramid_num_inference_steps_list) + _, _, _, pyramid_height, pyramid_width = latents.shape + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width) + for _ in range(len(pyramid_num_inference_steps_list) - 1): + pyramid_height //= 2 + pyramid_width //= 2 + latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="bilinear") * 2 + latents = latents.reshape(batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width).permute(0, 2, 1, 3, 4) + start_point_list = 
[latents] + completed_steps = 0 + + for stage_idx, stage_steps in enumerate(pyramid_num_inference_steps_list): + patch_size = self.model.transformer.config.patch_size + image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // (patch_size[0] * patch_size[1] * patch_size[2]) + mu = calculate_shift( + image_seq_len, + self.scheduler.inner.config.get("base_image_seq_len", 256), + self.scheduler.inner.config.get("max_image_seq_len", 4096), + self.scheduler.inner.config.get("base_shift", 0.5), + self.scheduler.inner.config.get("max_shift", 1.15), + ) + self.scheduler.set_timesteps(stage_steps, stage_idx, device=device, mu=mu, is_amplify_first_chunk=is_amplify_first_chunk and is_first_chunk) + timesteps = self.scheduler.timesteps + + if stage_idx > 0: + pyramid_height *= 2 + pyramid_width *= 2 + latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_latent_frames_per_chunk, num_channels_latents, pyramid_height // 2, pyramid_width // 2) + latents = F.interpolate(latents, size=(pyramid_height, pyramid_width), mode="nearest") + latents = latents.reshape(batch_size, num_latent_frames_per_chunk, num_channels_latents, pyramid_height, pyramid_width).permute(0, 2, 1, 3, 4) + ori_sigma = 1 - self.scheduler.inner.ori_start_sigmas[stage_idx] + gamma = self.scheduler.inner.config.gamma + alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) + beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) + noise = self.sample_block_noise(batch_size, num_channels_latents, latents.shape[2], pyramid_height, pyramid_width, patch_size, device, self.scheduler.generator).to(dtype=transformer_dtype) + latents = alpha * latents + beta * noise + start_point_list.append(latents) + + for step_idx, timestep_scalar in enumerate(timesteps): + timestep = timestep_scalar.expand(latents.shape[0]).to(torch.int64) + history_inputs = { + "indices_hidden_states": indices_hidden_states, + "indices_latents_history_short": indices_latents_history_short, + 
"indices_latents_history_mid": indices_latents_history_mid, + "indices_latents_history_long": indices_latents_history_long, + "latents_history_short": latents_history_short.to(transformer_dtype), + "latents_history_mid": latents_history_mid.to(transformer_dtype), + "latents_history_long": latents_history_long.to(transformer_dtype), + } + noise_pred = self.model.infer_cfg( + latents.to(transformer_dtype), + timestep, + prompt_embeds, + negative_prompt_embeds, + history_inputs, + guidance_scale=guidance_scale, + attention_kwargs=attention_kwargs, + is_cfg_zero_star=self.config.get("is_cfg_zero_star", False), + use_zero_init=use_zero_init, + zero_steps=zero_steps, + stage_idx=stage_idx, + step_idx=step_idx, + ) + latents = self.scheduler.step( + noise_pred, + timestep_scalar, + latents, + generator=self.scheduler.generator, + return_dict=False, + cur_sampling_step=step_idx, + dmd_noisy_tensor=start_point_list[stage_idx], + dmd_sigmas=self.scheduler.sigmas, + dmd_timesteps=self.scheduler.timesteps, + all_timesteps=timesteps, + )[0] + completed_steps += 1 + if self.progress_callback: + self.progress_callback((completed_steps / num_inference_steps) * 100, 100) + + if self.keep_first_frame and ((is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk)): + image_latents = latents[:, :, 0:1, :, :] + + total_generated_latent_frames += latents.shape[2] + history_latents = torch.cat([history_latents, latents], dim=2) + real_history_latents = history_latents[:, :, -total_generated_latent_frames:] + current_latents = real_history_latents[:, :, -num_latent_frames_per_chunk:] + current_video = self.vae_decoder.decode(current_latents) + history_video = current_video if history_video is None else torch.cat([history_video, current_video], dim=2) + + self.gen_video = history_video + self.gen_video_final = pt_video_output_to_comfy_frames( + finalize_video_output( + history_video=self.gen_video, + video_processor=self.vae_decoder.video_processor, + 
def process_images_after_vae_decoder_helios(self):
    """Optionally interpolate the generated frames, then return or save them.

    When the config has a ``video_frame_interpolation`` section, ``self.gen_video_final`` is replaced
    by the output of ``self.vfi_model.interpolate_frames``. The frames are returned when
    ``input_info.return_result_tensor`` is set; otherwise they are written to
    ``input_info.save_result_path`` (only on rank 0 when ``torch.distributed`` is initialised).

    Returns:
        dict: ``{"video": frames}`` when the tensor is requested, else ``{"video": None}``.

    Raises:
        ValueError: If frame interpolation is configured but no ``vfi_model`` is loaded or
            ``target_fps`` is missing. This is an explicit check so it still runs under ``python -O``.
    """
    source_fps = self.config.get("fps", 16)
    target_fps = None

    if "video_frame_interpolation" in self.config:
        target_fps = self.config["video_frame_interpolation"].get("target_fps", None)
        vfi_model = getattr(self, "vfi_model", None)
        if vfi_model is None or target_fps is None:
            raise ValueError(
                "video_frame_interpolation is configured but requires a loaded vfi_model and a target_fps value"
            )
        logger.info(f"Interpolating frames from {source_fps} to {target_fps}")
        self.gen_video_final = vfi_model.interpolate_frames(
            self.gen_video_final,
            source_fps=source_fps,
            target_fps=target_fps,
        )

    if self.input_info.return_result_tensor:
        return {"video": self.gen_video_final}

    if self.input_info.save_result_path is not None:
        # Interpolated output plays at the target rate; everything else at the source rate.
        fps = target_fps if target_fps else source_fps
        if not dist.is_initialized() or dist.get_rank() == 0:
            out_path = self.input_info.save_result_path
            logger.info("🎬 Start to save video 🎬")
            from lightx2v.utils.utils import save_to_video

            save_to_video(self.gen_video_final, out_path, fps=fps, method="ffmpeg")
            logger.info(f"✅ Video saved successfully to: {out_path} ✅")
    return {"video": None}
image_noise_sigma_min) + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * torch.randn(image_latents.shape, generator=generator, device=device) + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * torch.randn(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + return image_latents, fake_image_latents + + +def trim_generated_frames(frame_count, temporal_scale_factor): + return ((frame_count - 1) // temporal_scale_factor) * temporal_scale_factor + 1 + + +def finalize_video_output(history_video, video_processor, temporal_scale_factor, output_type="pt"): + generated_frames = trim_generated_frames(history_video.size(2), temporal_scale_factor) + history_video = history_video[:, :, :generated_frames] + return video_processor.postprocess_video(history_video, output_type=output_type) + + +def pt_video_output_to_comfy_frames(video): + if video.dim() != 5: + raise ValueError(f"Expected [B, T, C, H, W] tensor, got shape {tuple(video.shape)}") + return video.permute(0, 1, 3, 4, 2).flatten(0, 1).cpu() diff --git a/lightx2v/models/schedulers/helios/__init__.py b/lightx2v/models/schedulers/helios/__init__.py new file mode 100644 index 000000000..9dcbc6794 --- /dev/null +++ b/lightx2v/models/schedulers/helios/__init__.py @@ -0,0 +1,3 @@ +from lightx2v.models.schedulers.helios.scheduler import HeliosDistilledScheduler + +__all__ = ["HeliosDistilledScheduler"] diff --git a/lightx2v/models/schedulers/helios/helios_dmd.py b/lightx2v/models/schedulers/helios/helios_dmd.py new file mode 100644 index 000000000..3e3d330d5 --- /dev/null +++ b/lightx2v/models/schedulers/helios/helios_dmd.py @@ -0,0 +1,331 @@ +# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved. 
class HeliosDMDScheduler(SchedulerMixin, ConfigMixin):
    """Multi-stage flow-matching scheduler with a DMD-style sampling step.

    Construction precomputes a global sigma/timestep table and, for every stage,
    a timestep range and a sigma ramp. `set_timesteps` then selects the schedule
    for one stage. `step` converts the model's flow prediction into an x0
    estimate and, except on the last sampling step, re-noises that estimate to
    the next timestep.
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,  # Following Stable diffusion 3,
        stages: int = 3,
        # NOTE(review): mutable default; the code in this class only reads it.
        stage_range: list = [0, 1 / 3, 2 / 3, 1],
        gamma: float = 1 / 3,
        prediction_type: str = "flow_prediction",
        use_flow_sigmas: bool = True,
        use_dynamic_shifting: bool = False,
        time_shift_type: Literal["exponential", "linear"] = "linear",
    ):
        """Build the global and per-stage schedules.

        Args:
            num_train_timesteps: Length of the global sigma/timestep table.
            shift: Shift applied to the global sigmas in `init_sigmas`, and to
                the sigmas in the single-stage branch of `set_timesteps`.
            stages: Number of stages.
            stage_range: Stage boundaries as fractions of `num_train_timesteps`;
                entries `i` and `i + 1` bound stage `i`.
            gamma: Used to correct the start sigma of every stage after the first.
            prediction_type: Not read by the code in this class.
            use_flow_sigmas: Not read by the code in this class.
            use_dynamic_shifting: When true, `set_timesteps` applies `time_shift`
                with the caller-supplied `mu`.
            time_shift_type: Selects the formula used by `time_shift`.
        """
        self.timestep_ratios = {}  # The timestep ratio for each stage
        self.timesteps_per_stage = {}  # The detailed timesteps per stage (fix max and min per stage)
        self.sigmas_per_stage = {}  # per-stage sigma ramp; the same uniform 0.999 -> 0 grid for every stage
        self.start_sigmas = {}  # for start point / upsample renoise
        self.end_sigmas = {}  # for end point
        self.ori_start_sigmas = {}  # start sigma of each stage before the gamma correction

        # self.init_sigmas()
        # init_sigmas_for_each_stage() calls init_sigmas() itself, which sets self.sigmas.
        self.init_sigmas_for_each_stage()
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma

        self.last_sample = None
        self._step_index = None
        self._begin_index = None

    def init_sigmas(self):
        """
        Initialize the global timesteps and sigmas.

        Builds `num_train_timesteps` sigmas, applies the configured `shift`, and orders
        them so the largest sigma comes first. Timesteps are the sigmas scaled by
        `num_train_timesteps`. Also resets the step and begin indices.
        """
        num_train_timesteps = self.config.num_train_timesteps
        shift = self.config.shift

        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
        sigmas = 1.0 - alphas
        # Shift, reverse into descending order, then drop the final entry (the zero sigma).
        sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
        sigmas = torch.from_numpy(sigmas)
        timesteps = (sigmas * num_train_timesteps).clone()

        self._step_index = None
        self._begin_index = None
        self.timesteps = timesteps
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def init_sigmas_for_each_stage(self):
        """
        Init the timesteps for each stage.

        Fills `ori_start_sigmas`, `start_sigmas`, `end_sigmas`, `timestep_ratios`,
        `timesteps_per_stage` and `sigmas_per_stage`, all keyed by stage index.
        """
        self.init_sigmas()

        stage_distance = []
        stages = self.config.stages
        training_steps = self.config.num_train_timesteps
        stage_range = self.config.stage_range

        # Init the start and end point of each stage
        for i_s in range(stages):
            # To decide the start and ends point
            start_indice = int(stage_range[i_s] * training_steps)
            start_indice = max(start_indice, 0)
            end_indice = int(stage_range[i_s + 1] * training_steps)
            end_indice = min(end_indice, training_steps)
            start_sigma = self.sigmas[start_indice].item()
            # The table has `training_steps` entries, so an end index equal to
            # `training_steps` has no entry; the end sigma is 0.0 in that case.
            end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
            self.ori_start_sigmas[i_s] = start_sigma

            if i_s != 0:
                # Stages after the first start from a gamma-corrected sigma.
                ori_sigma = 1 - start_sigma
                gamma = self.config.gamma
                corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
                # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
                start_sigma = 1 - corrected_sigma

            stage_distance.append(start_sigma - end_sigma)
            self.start_sigmas[i_s] = start_sigma
            self.end_sigmas[i_s] = end_sigma

        # Determine the ratio of each stage according to flow length
        tot_distance = sum(stage_distance)
        for i_s in range(stages):
            if i_s == 0:
                start_ratio = 0.0
            else:
                start_ratio = sum(stage_distance[:i_s]) / tot_distance
            if i_s == stages - 1:
                # Just below 1.0 for the last stage.
                end_ratio = 0.9999999999999999
            else:
                end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance

            self.timestep_ratios[i_s] = (start_ratio, end_ratio)

        # Determine the timesteps and sigmas for each stage
        for i_s in range(stages):
            timestep_ratio = self.timestep_ratios[i_s]
            # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
            # The stage's largest timestep is capped at 999.
            timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
            timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
            timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
            self.timesteps_per_stage[i_s] = (
                timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
            )
            stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])

    @property
    def step_index(self):
        """
        The stored step index.

        The code in this class only ever resets it to `None`.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep, as stored by `set_begin_index`.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        """Scale a sigma by `num_train_timesteps` to get the matching timestep."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        stage_index: int | None = None,
        device: str | torch.device = None,
        # NOTE(review): annotated as bool, but used below as a NumPy array
        # (`.copy()`, `torch.from_numpy`).
        sigmas: bool | None = None,
        # NOTE(review): annotated as bool, but passed to `time_shift` as `mu`.
        mu: bool | None = None,
        is_amplify_first_chunk: bool = False,
    ):
        """
        Set `self.timesteps` and `self.sigmas` for one stage.

        Args:
            num_inference_steps: Requested number of steps. One extra grid point is
                added internally (or the count is doubled first when
                `is_amplify_first_chunk` is true) and the last timestep is dropped again.
            stage_index: Stage whose precomputed range is used when `stages != 1`.
            device: Device the resulting tensors are moved to.
            sigmas: Optional custom sigmas; only read in the single-stage branch.
            mu: Shift parameter; only read when `use_dynamic_shifting` is enabled.
            is_amplify_first_chunk: Doubles `num_inference_steps` before the extra
                grid point is added.
        """
        if is_amplify_first_chunk:
            num_inference_steps = num_inference_steps * 2 + 1
        else:
            num_inference_steps = num_inference_steps + 1

        self.num_inference_steps = num_inference_steps
        self.init_sigmas()

        if self.config.stages == 1:
            if sigmas is None:
                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
                    np.float32
                )
            # NOTE(review): as written, the static shift also applies to
            # caller-supplied `sigmas` - confirm against upstream.
            if self.config.shift != 1.0:
                assert not self.config.use_dynamic_shifting
                sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = torch.from_numpy(sigmas)
        else:
            # Resample the stage's precomputed range on `num_inference_steps` points.
            stage_timesteps = self.timesteps_per_stage[stage_index]
            timesteps = np.linspace(
                stage_timesteps[0].item(),
                stage_timesteps[-1].item(),
                num_inference_steps,
            )

            stage_sigmas = self.sigmas_per_stage[stage_index]
            ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
            sigmas = torch.from_numpy(ratios)

        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)

        self._step_index = None
        self.reset_scheduler_history()

        # Drop the extra grid point added above: the last timestep and the sigma
        # just before the appended zero.
        self.timesteps = self.timesteps[:-1]
        self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])

        if self.config.use_dynamic_shifting:
            assert self.config.shift == 1.0
            self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
            # NOTE(review): timesteps are re-derived from the shifted sigmas only
            # on this dynamic-shifting path - confirm against upstream.
            if self.config.stages == 1:
                self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
            else:
                self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
                    self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
                )

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """
        Apply time shifting to the sigmas.

        Args:
            mu (`float`):
                The mu parameter for the time shift.
            sigma (`float`):
                The sigma parameter for the time shift.
            t (`torch.Tensor`):
                The input timesteps.

        Returns:
            `torch.Tensor`:
                The time-shifted timesteps. Falls through and returns `None` when
                `time_shift_type` is neither "exponential" nor "linear".
        """
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
    def _time_shift_exponential(self, mu, sigma, t):
        """Return `exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`."""
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
    def _time_shift_linear(self, mu, sigma, t):
        """Return `mu / (mu + (1 / t - 1) ** sigma)`."""
        return mu / (mu + (1 / t - 1) ** sigma)

    # ---------------------------------- For DMD ----------------------------------
    def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):
        """Mix `noise` into `original_samples` at the sigma belonging to `timestep`.

        For each entry of `timestep`, the closest value in `timesteps` selects the
        sigma. The result is `(1 - sigma) * original_samples + sigma * noise`, cast
        to the type of `noise`.
        """
        sigmas = sigmas.to(noise.device)
        timesteps = timesteps.to(noise.device)
        # Nearest-timestep lookup, one index per batch entry.
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        # Reshaped so the per-sample sigma broadcasts over 5-D tensors.
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample.type_as(noise)

    def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps):
        """Turn a flow prediction into an x0 estimate: `x0 = xt - sigma_t * flow_pred`.

        `sigma_t` is the sigma whose timestep is closest to `timestep`. The arithmetic
        runs in float64 and the result is cast back to the dtype of `flow_pred`.
        """
        # use higher precision for calculations
        original_dtype = flow_pred.dtype
        device = flow_pred.device
        flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps))

        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
        x0_pred = xt - sigma_t * flow_pred
        return x0_pred.to(original_dtype)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor = None,
        sample: torch.FloatTensor = None,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
        cur_sampling_step: int = 0,
        dmd_noisy_tensor: torch.FloatTensor | None = None,
        dmd_sigmas: torch.FloatTensor | None = None,
        dmd_timesteps: torch.FloatTensor | None = None,
        all_timesteps: torch.FloatTensor | None = None,
    ) -> HeliosDMDSchedulerOutput | tuple:
        """Run one sampling step.

        The flow prediction is converted to an x0 estimate. Unless this is the last
        entry of `all_timesteps`, the estimate is re-noised with `dmd_noisy_tensor`
        at the next timestep; on the last entry the estimate is returned as is.

        Args:
            model_output: Flow prediction from the model.
            timestep: Current timestep; broadcast to a `long` tensor with one entry
                per batch element.
            sample: Current noisy sample (`xt`).
            generator: Not used by this implementation.
            return_dict: When false, return a 1-tuple instead of the output dataclass.
            cur_sampling_step: Index of the current step within `all_timesteps`.
            dmd_noisy_tensor: Noise tensor used when re-noising the x0 estimate.
            dmd_sigmas: Sigma table used for the nearest-timestep lookups.
            dmd_timesteps: Timestep table matching `dmd_sigmas`.
            all_timesteps: Timesteps of the current sampling loop. Although it defaults
                to `None`, `len(all_timesteps)` is evaluated unconditionally, so a value
                must be supplied.

        Returns:
            `HeliosDMDSchedulerOutput` with `prev_sample` set, or `(prev_sample,)`
            when `return_dict` is false.
        """
        pred_image_or_video = self.convert_flow_pred_to_x0(
            flow_pred=model_output,
            xt=sample,
            timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device),
            sigmas=dmd_sigmas,
            timesteps=dmd_timesteps,
        )
        if cur_sampling_step < len(all_timesteps) - 1:
            prev_sample = self.add_noise(
                pred_image_or_video,
                dmd_noisy_tensor,
                torch.full(
                    (model_output.shape[0],),
                    all_timesteps[cur_sampling_step + 1],
                    dtype=torch.long,
                    device=model_output.device,
                ),
                sigmas=dmd_sigmas,
                timesteps=dmd_timesteps,
            )
        else:
            prev_sample = pred_image_or_video

        if not return_dict:
            return (prev_sample,)

        return HeliosDMDSchedulerOutput(prev_sample=prev_sample)

    def reset_scheduler_history(self):
        """Reset the step and begin indices to `None`."""
        self._step_index = None
        self._begin_index = None

    def __len__(self):
        """Return the configured number of training timesteps."""
        return self.config.num_train_timesteps
class HeliosDistilledScheduler(BaseScheduler):
    """Adapter exposing a `HeliosDMDScheduler` through the `BaseScheduler` interface.

    The wrapped scheduler is kept in `self.inner`; its current `timesteps` and
    `sigmas` are copied onto this object after every `set_timesteps` call.
    """

    def __init__(self, config):
        """Load the inner scheduler from `config["scheduler_path"]`.

        `config["infer_steps"]` is overwritten with the sum of
        `pyramid_num_inference_steps_list` (default `[2, 2, 2]`) before the
        base class is initialised.
        """
        per_stage_steps = config.get("pyramid_num_inference_steps_list", [2, 2, 2])
        config["infer_steps"] = sum(per_stage_steps)
        super().__init__(config)
        self.inner = HeliosDMDScheduler.from_pretrained(config["scheduler_path"])
        self.sample_guide_scale = config.get("sample_guide_scale", 1.0)
        # Runtime state, filled in by prepare() / set_timesteps().
        self.generator = self.latents = self.timesteps = self.sigmas = None

    def prepare(self, seed, latent_shape, image_encoder_output=None, generator=None):
        """Store the per-run inputs.

        A caller-supplied `generator` is kept as is; otherwise a new generator on
        `AI_DEVICE` is seeded with `seed`.
        """
        if generator is None:
            generator = torch.Generator(device=AI_DEVICE).manual_seed(seed)
        self.generator = generator
        self.latent_shape = latent_shape
        self.image_encoder_output = image_encoder_output

    def set_timesteps(self, num_inference_steps, stage_idx, device, mu, is_amplify_first_chunk):
        """Configure the inner scheduler for one stage and mirror its schedule here."""
        inner = self.inner
        inner.set_timesteps(
            num_inference_steps,
            stage_idx,
            device=device,
            mu=mu,
            is_amplify_first_chunk=is_amplify_first_chunk,
        )
        self.timesteps = inner.timesteps
        self.sigmas = inner.sigmas
        self.infer_steps = len(self.timesteps)
        self.step_index = 0

    def step(self, *args, **kwargs):
        """Forward all arguments to the inner scheduler's `step`."""
        return self.inner.step(*args, **kwargs)

    def step_pre(self, step_index):
        """Record the index of the step about to run."""
        self.step_index = step_index

    def clear(self):
        """Drop the cached latents, timesteps and sigmas."""
        self.latents = self.timesteps = self.sigmas = None
class HeliosVAE:
    """Wrapper around `AutoencoderKLWan` for image conditioning and latent decoding.

    With CPU offload enabled the model lives on the CPU and is moved to
    `AI_DEVICE` only for the duration of an encode or decode call.
    """

    def __init__(self, config):
        """Load the VAE from `config["vae_path"]` in float32.

        Offload is controlled by `vae_cpu_offload`, falling back to `cpu_offload`.
        Tiling is enabled when `use_tiling_vae` is set.
        """
        self.config = config
        self.cpu_offload = config.get("vae_cpu_offload", config.get("cpu_offload", False))
        self.device = torch.device("cpu") if self.cpu_offload else torch.device(AI_DEVICE)
        self.dtype = torch.float32
        self.model = AutoencoderKLWan.from_pretrained(config["vae_path"], torch_dtype=self.dtype).to(self.device)
        self.vae_scale_factor_temporal = getattr(self.model.config, "scale_factor_temporal", 4)
        self.vae_scale_factor_spatial = getattr(self.model.config, "scale_factor_spatial", 8)
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        self.latents_mean = torch.tensor(self.model.config.latents_mean).view(1, self.model.config.z_dim, 1, 1, 1)
        # Stored as the reciprocal of the configured std: encoding multiplies
        # by it and decoding divides by it.
        self.latents_std = 1.0 / torch.tensor(self.model.config.latents_std).view(1, self.model.config.z_dim, 1, 1, 1)
        if config.get("use_tiling_vae", False):
            self.model.enable_tiling()

    def _to_device(self):
        """Move the model to `AI_DEVICE` when CPU offload is enabled."""
        if self.cpu_offload:
            self.model.to(torch.device(AI_DEVICE))

    def _to_cpu(self):
        """Move the model back to the CPU and release cached memory when offloading."""
        if self.cpu_offload:
            self.model.to(torch.device("cpu"))
            torch.cuda.empty_cache()
            gc.collect()

    def preprocess_image(self, image_path_or_pil, height, width):
        """Load the image if a path was given and run it through the video processor."""
        image = image_path_or_pil
        if isinstance(image, (str, os.PathLike)):
            image = load_image(str(image))
        return self.video_processor.preprocess(image, height=height, width=width)

    def prepare_image_latents(self, image, generator, num_latent_frames_per_chunk, height, width, dtype=torch.float32):
        """Encode the conditioning image and a repeated-frame "fake video".

        Returns:
            Tuple `(image_latents, fake_latents)`, both normalised with the VAE's
            latent mean and std and cast to `dtype`. `fake_latents` is the last
            latent frame of the encoded fake video.
        """
        self._to_device()
        try:
            image = self.preprocess_image(image, height, width).unsqueeze(2).to(device=self.model.device, dtype=self.model.dtype)
            latents_mean = self.latents_mean.to(device=self.model.device, dtype=self.model.dtype)
            latents_std = self.latents_std.to(device=self.model.device, dtype=self.model.dtype)
            image_latents = self.model.encode(image).latent_dist.sample(generator=generator)
            image_latents = (image_latents - latents_mean) * latents_std

            # Repeat the image for as many frames as one latent chunk spans.
            min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
            fake_video = image.repeat(1, 1, min_frames, 1, 1)
            fake_latents = self.model.encode(fake_video).latent_dist.sample(generator=generator)
            fake_latents = (fake_latents - latents_mean) * latents_std
            fake_latents = fake_latents[:, :, -1:, :, :]
        finally:
            # Offload again even if encoding failed, so the VAE is not left on
            # the accelerator.
            self._to_cpu()
        return image_latents.to(dtype=dtype), fake_latents.to(dtype=dtype)

    def decode(self, latents):
        """De-normalise `latents` and decode them; returns the first element of the model's output."""
        self._to_device()
        try:
            latents_mean = self.latents_mean.to(device=latents.device, dtype=latents.dtype)
            latents_std = self.latents_std.to(device=latents.device, dtype=latents.dtype)
            current_latents = latents.to(self.model.device, dtype=self.model.dtype) / latents_std + latents_mean
            decoded = self.model.decode(current_latents, return_dict=False)[0]
        finally:
            # Offload again even if decoding failed.
            self._to_cpu()
        return decoded
self.vae_stride = (4, 16, 16) self.num_channels_latents = 32 + elif self.model_cls in ["helios"]: + self.vae_stride = (4, 8, 8) + self.num_channels_latents = 16 elif self.model_cls in ["ltx2"]: self.num_channels_latents = 128 self.audio_mel_bins = 16 @@ -248,7 +252,7 @@ def set_infer_config( self.self_attn_1_type = attn_mode self.cross_attn_1_type = attn_mode self.cross_attn_2_type = attn_mode - elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image", "longcat_image", "ltx2", "z_image"]: + elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image", "longcat_image", "ltx2", "z_image", "helios"]: self.attn_type = attn_mode self.norm_modulate_backend = norm_modulate_backend @@ -304,6 +308,10 @@ def enable_quantize( self.qwen25vl_quantized = text_encoder_quantized self.qwen25vl_quantized_ckpt = text_encoder_quantized_ckpt self.qwen25vl_quant_scheme = text_encoder_quant_scheme + elif self.model_cls == "helios": + self.text_encoder_quantized = text_encoder_quantized + self.text_encoder_quantized_ckpt = text_encoder_quantized_ckpt + self.text_encoder_quant_scheme = text_encoder_quant_scheme elif self.model_cls in ["ltx2"]: self.skip_fp8_block_index = skip_fp8_block_index elif self.model_cls == "z_image": @@ -345,6 +353,8 @@ def enable_offload( self.qwen25vl_cpu_offload = text_encoder_offload self.siglip_cpu_offload = image_encoder_offload self.byt5_cpu_offload = image_encoder_offload + elif self.model_cls == "helios": + self.text_encoder_cpu_offload = text_encoder_offload elif self.model_cls in ["qwen_image", "longcat_image"]: self.qwen25vl_cpu_offload = text_encoder_offload elif self.model_cls == "ltx2": diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index 9d14cf7ae..f90faa1d2 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -121,6 +121,38 @@ def auto_calc_config(config): config["config_path"] = wm_config if wm_ckpt: config["ckpt_path"] = wm_ckpt + elif 
config["model_cls"] == "helios": + transformer_path = os.path.join(config["model_path"], "transformer") + config["transformer_model_path"] = transformer_path + config["text_encoder_path"] = config.get("text_encoder_path", os.path.join(config["model_path"], "text_encoder")) + config["tokenizer_path"] = config.get("tokenizer_path", os.path.join(config["model_path"], "tokenizer")) + config["vae_path"] = config.get("vae_path", os.path.join(config["model_path"], "vae")) + config["scheduler_path"] = config.get("scheduler_path", os.path.join(config["model_path"], "scheduler")) + config["max_sequence_length"] = config.get("max_sequence_length", 512) + + model_index_path = os.path.join(config["model_path"], "model_index.json") + if os.path.exists(model_index_path): + with open(model_index_path, "r", encoding="utf-8") as f: + model_index = json.load(f) + config["is_distilled"] = bool(model_index.get("is_distilled", config.get("is_distilled", False))) + config["model_variant"] = config.get("model_variant", "distilled" if config["is_distilled"] else "base") + scheduler_entry = model_index.get("scheduler") + if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2: + config["scheduler_type"] = scheduler_entry[1] + + if os.path.exists(os.path.join(transformer_path, "config.json")): + with open(os.path.join(transformer_path, "config.json"), "r", encoding="utf-8") as f: + model_config = json.load(f) + config.update(model_config) + + scheduler_config_path = os.path.join(config["model_path"], "scheduler", "scheduler_config.json") + if os.path.exists(scheduler_config_path): + with open(scheduler_config_path, "r", encoding="utf-8") as f: + scheduler_config = json.load(f) + config["scheduler_type"] = scheduler_config.get("_class_name", config.get("scheduler_type")) + for key in ("stages", "stage_range", "shift", "prediction_type", "time_shift_type", "use_dynamic_shifting", "use_flow_sigmas"): + if key in scheduler_config: + config[key] = scheduler_config[key] elif 
config["model_cls"] == "longcat_image": # Special config for longcat_image: load both root and transformer config if os.path.exists(os.path.join(config["model_path"], "config.json")): with open(os.path.join(config["model_path"], "config.json"), "r") as f: diff --git a/lightx2v/utils/utils.py b/lightx2v/utils/utils.py index 3b2993517..5e0826700 100755 --- a/lightx2v/utils/utils.py +++ b/lightx2v/utils/utils.py @@ -736,6 +736,14 @@ def validate_config_paths(config: dict) -> None: check_path_exists(config["dit_original_ckpt"]) logger.debug(f"✓ Verified dit_original_ckpt: {config['dit_original_ckpt']}") + if config.get("model_cls") == "helios": + for key in ("transformer_model_path", "text_encoder_path", "tokenizer_path", "vae_path", "scheduler_path"): + if key in config and config[key] is not None: + check_path_exists(config[key]) + logger.debug(f"✓ Verified {key}: {config[key]}") + logger.info("✓ Config checkpoint paths validated successfully") + return + # For wan2.2, check high and low noise checkpoints model_cls = config.get("model_cls", "") if model_cls and "wan2.2" in model_cls: diff --git a/scripts/helios/run_helios_distilled_i2v.sh b/scripts/helios/run_helios_distilled_i2v.sh new file mode 100644 index 000000000..6c651c795 --- /dev/null +++ b/scripts/helios/run_helios_distilled_i2v.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +lightx2v_path= +model_path=/data1/models/BestWishYSH/Helios-Distilled +image_path= + +export CUDA_VISIBLE_DEVICES=0 + +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls helios \ +--task i2v \ +--model_path ${model_path} \ +--config_json ${lightx2v_path}/configs/helios/helios_distilled_i2v.json \ +--image_path ${image_path} \ +--prompt "The scene comes alive with subtle camera motion and realistic atmospheric movement." 
def load_module(module_name: str, file_path: Path):
    """Import the Python source at ``file_path`` under the name ``module_name``.

    The module is registered in ``sys.modules`` before it is executed, so
    code in the file can look itself up by name. If execution fails, the
    registration is removed again so no half-initialised module is left behind.

    Args:
        module_name: Name to register the module under.
        file_path: Path of the source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for ``file_path``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None for paths it has no loader for.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_name!r} from {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Same policy as the import system: drop the entry on failure.
        sys.modules.pop(module_name, None)
        raise
    return module
+ transformers_module.UMT5EncoderModel = object + sys.modules.setdefault("transformers", transformers_module) + + envs_module = types.ModuleType("lightx2v.utils.envs") + envs_module.GET_DTYPE = lambda: torch.bfloat16 + sys.modules.setdefault("lightx2v.utils.envs", envs_module) + + global_var_module = types.ModuleType("lightx2v_platform.base.global_var") + global_var_module.AI_DEVICE = "cpu" + sys.modules.setdefault("lightx2v_platform.base.global_var", global_var_module) + + cls.text_module = load_module( + "test_helios_text_model", + REPO_ROOT / "lightx2v/models/input_encoders/hf/helios/model.py", + ) + + def test_pack_prompt_embeds_reapplies_sequence_lengths_before_padding(self): + hidden_state = torch.tensor( + [ + [[1.0, 10.0], [2.0, 20.0], [999.0, 999.0], [999.0, 999.0]], + [[3.0, 30.0], [4.0, 40.0], [5.0, 50.0], [999.0, 999.0]], + ] + ) + attention_mask = torch.tensor( + [ + [1, 1, 0, 0], + [1, 1, 1, 0], + ] + ) + + prompt_embeds, mask = self.text_module.pack_t5_prompt_embeds( + hidden_state, + attention_mask, + max_sequence_length=4, + num_videos_per_prompt=2, + dtype=torch.bfloat16, + device=torch.device("cpu"), + ) + + self.assertEqual(tuple(prompt_embeds.shape), (4, 4, 2)) + self.assertEqual(prompt_embeds.dtype, torch.bfloat16) + self.assertTrue(mask.dtype == torch.bool) + self.assertTrue(torch.equal(mask, attention_mask.bool())) + self.assertTrue(torch.equal(prompt_embeds[0, 2:], torch.zeros((2, 2), dtype=torch.bfloat16))) + self.assertTrue(torch.equal(prompt_embeds[1], prompt_embeds[0])) + self.assertTrue(torch.equal(prompt_embeds[2, 3:], torch.zeros((1, 2), dtype=torch.bfloat16))) + self.assertEqual(prompt_embeds[2, 2, 0].item(), 5.0) + + +class HeliosRuntimeUtilsTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.runtime_utils = load_module( + "test_helios_runtime_utils", + REPO_ROOT / "lightx2v/models/runners/helios/runtime_utils.py", + ) + + def test_apply_image_condition_noise_uses_distinct_sigmas_and_generator_order(self): + 
image_latents = torch.ones((1, 1, 1, 1, 1), dtype=torch.float32) + fake_image_latents = torch.full((1, 1, 1, 1, 1), 2.0, dtype=torch.float32) + generator = torch.Generator(device="cpu").manual_seed(123) + + noisy_image, noisy_fake = self.runtime_utils.apply_image_condition_noise( + image_latents=image_latents, + fake_image_latents=fake_image_latents, + generator=generator, + device=torch.device("cpu"), + image_noise_sigma_min=0.111, + image_noise_sigma_max=0.135, + video_noise_sigma_min=0.211, + video_noise_sigma_max=0.235, + ) + + ref_generator = torch.Generator(device="cpu").manual_seed(123) + image_sigma = torch.rand(1, device="cpu", generator=ref_generator) * (0.135 - 0.111) + 0.111 + ref_noisy_image = image_sigma * torch.randn(image_latents.shape, generator=ref_generator) + (1 - image_sigma) * image_latents + fake_sigma = torch.rand(1, device="cpu", generator=ref_generator) * (0.235 - 0.211) + 0.211 + ref_noisy_fake = fake_sigma * torch.randn(fake_image_latents.shape, generator=ref_generator) + (1 - fake_sigma) * fake_image_latents + + self.assertTrue(torch.allclose(noisy_image, ref_noisy_image)) + self.assertTrue(torch.allclose(noisy_fake, ref_noisy_fake)) + + def test_trim_and_postprocess_video_matches_helios_frame_rule(self): + history_video = torch.arange(1 * 3 * 99 * 2 * 2, dtype=torch.float32).reshape(1, 3, 99, 2, 2) + + class DummyVideoProcessor: + def __init__(self): + self.called = False + self.last_shape = None + + def postprocess_video(self, video, output_type="np"): + self.called = True + self.last_shape = tuple(video.shape) + return {"frames": video.clone(), "output_type": output_type} + + processor = DummyVideoProcessor() + result = self.runtime_utils.finalize_video_output( + history_video=history_video, + video_processor=processor, + temporal_scale_factor=4, + output_type="np", + ) + + self.assertTrue(processor.called) + self.assertEqual(processor.last_shape, (1, 3, 97, 2, 2)) + self.assertEqual(tuple(result["frames"].shape), (1, 3, 97, 2, 2)) + 
class HeliosI2VGeneratorContinuityTest(unittest.TestCase):
    """Checks that the Helios i2v path keeps using a single RNG generator."""

    @classmethod
    def setUpClass(cls):
        """Load `scheduler.py` against lightweight stand-ins for its imports.

        The stand-ins are installed in `sys.modules` only for the lifetime of
        this test class; a class cleanup restores whatever was registered
        before, so later tests in the same process see the real modules.
        """
        base_scheduler_module = types.ModuleType("lightx2v.models.schedulers.scheduler")

        class BaseScheduler:
            def __init__(self, config):
                self.config = config

        base_scheduler_module.BaseScheduler = BaseScheduler

        fake_dmd_module = types.ModuleType("lightx2v.models.schedulers.helios.helios_dmd")

        class FakeInnerScheduler:
            config = types.SimpleNamespace()

            @classmethod
            def from_pretrained(cls, _path):
                return cls()

        fake_dmd_module.HeliosDMDScheduler = FakeInnerScheduler

        global_var_module = types.ModuleType("lightx2v_platform.base.global_var")
        global_var_module.AI_DEVICE = "cpu"

        loaded_name = "test_helios_scheduler_module"
        stubs = {
            "lightx2v.models.schedulers.scheduler": base_scheduler_module,
            "lightx2v.models.schedulers.helios.helios_dmd": fake_dmd_module,
            "lightx2v_platform.base.global_var": global_var_module,
        }
        # Remember which names were already registered so they can be restored.
        previous = {name: sys.modules[name] for name in (*stubs, loaded_name) if name in sys.modules}

        def restore_modules():
            for name in (*stubs, loaded_name):
                if name in previous:
                    sys.modules[name] = previous[name]
                else:
                    sys.modules.pop(name, None)

        cls.addClassCleanup(restore_modules)
        # Override unconditionally: with `setdefault`, an already-imported real
        # module would be used and the test would try to load real weights.
        sys.modules.update(stubs)

        cls.scheduler_module = load_module(
            loaded_name,
            REPO_ROOT / "lightx2v/models/schedulers/helios/scheduler.py",
        )

    def test_prepare_reuses_external_generator_for_i2v_rng_continuity(self):
        """`prepare` must keep the generator it is given instead of seeding a new one."""
        scheduler = self.scheduler_module.HeliosDistilledScheduler(
            {
                "scheduler_path": "/tmp/unused",
                "pyramid_num_inference_steps_list": [2, 2, 2],
                "sample_guide_scale": 1.0,
            }
        )
        external_generator = torch.Generator(device="cpu").manual_seed(42)
        scheduler.prepare(seed=999, latent_shape=[16, 25, 48, 80], image_encoder_output={}, generator=external_generator)
        self.assertIs(scheduler.generator, external_generator)

    def test_helios_runner_i2v_no_longer_reseeds_a_second_generator(self):
        """The runner's i2v encoder block must not contain a second `manual_seed` call."""
        runner_path = REPO_ROOT / "lightx2v/models/runners/helios/helios_runner.py"
        source = runner_path.read_text(encoding="utf-8")
        # Source text between the i2v encoder method and `sample_block_noise`.
        i2v_block = source.split("def _run_input_encoder_local_i2v", 1)[1].split("def sample_block_noise", 1)[0]
        self.assertNotIn("manual_seed(self.input_info.seed)", i2v_block)
+lightx2v_utils_module.is_main_process = lambda: True +sys.modules.setdefault("lightx2v.utils.utils", lightx2v_utils_module) + +if "loguru" not in sys.modules: + sys.modules["loguru"] = types.SimpleNamespace( + logger=types.SimpleNamespace(info=lambda *a, **k: None, warning=lambda *a, **k: None, debug=lambda *a, **k: None) + ) +if "psutil" not in sys.modules: + sys.modules["psutil"] = types.SimpleNamespace(virtual_memory=lambda: types.SimpleNamespace(available=0)) +if "huggingface_hub" not in sys.modules: + hf_module = types.ModuleType("huggingface_hub") + hf_module.HfApi = object + hf_module.list_repo_files = lambda *a, **k: [] + hf_module.__spec__ = importlib.machinery.ModuleSpec("huggingface_hub", loader=None) + sys.modules["huggingface_hub"] = hf_module +if "modelscope" not in sys.modules: + modelscope_module = types.ModuleType("modelscope") + modelscope_module.__spec__ = importlib.machinery.ModuleSpec("modelscope", loader=None) + hub_module = types.ModuleType("modelscope.hub") + hub_module.__spec__ = importlib.machinery.ModuleSpec("modelscope.hub", loader=None) + api_module = types.ModuleType("modelscope.hub.api") + api_module.__spec__ = importlib.machinery.ModuleSpec("modelscope.hub.api", loader=None) + api_module.HubApi = object + sys.modules["modelscope"] = modelscope_module + sys.modules["modelscope.hub"] = hub_module + sys.modules["modelscope.hub.api"] = api_module +if "torch" not in sys.modules: + torch_module = types.ModuleType("torch") + torch_module.__spec__ = importlib.machinery.ModuleSpec("torch", loader=None) + torch_module.float16 = "float16" + torch_module.float32 = "float32" + torch_module.bfloat16 = "bfloat16" + torch_module.Tensor = object + torch_module._scaled_mm = object() + torch_module.cuda = types.SimpleNamespace( + is_available=lambda: False, + get_device_capability=lambda *_: (0, 0), + get_device_name=lambda *_: "", + empty_cache=lambda: None, + synchronize=lambda: None, + ) + torch_module.device = lambda value: value + + dist_module = 
types.ModuleType("torch.distributed") + dist_module.is_initialized = lambda: False + dist_module.get_rank = lambda: 0 + dist_module.get_world_size = lambda: 1 + dist_module.all_reduce = lambda *_args, **_kwargs: None + + tensor_module = types.ModuleType("torch.distributed.tensor") + device_mesh_module = types.ModuleType("torch.distributed.tensor.device_mesh") + device_mesh_module.init_device_mesh = lambda *_args, **_kwargs: None + + torch_module.distributed = dist_module + sys.modules["torch"] = torch_module + sys.modules["torch.distributed"] = dist_module + sys.modules["torch.distributed.tensor"] = tensor_module + sys.modules["torch.distributed.tensor.device_mesh"] = device_mesh_module + +from utils.model_utils import get_model_configs +from lightx2v.utils.set_config import set_config + + +class HeliosDistilledSupportTest(unittest.TestCase): + def test_get_model_configs_detects_helios_distilled_variant(self): + config = get_model_configs( + model_type_input="Helios", + model_path_input="/data1/models/BestWishYSH/Helios-Distilled", + dit_path_input=None, + high_noise_path_input=None, + low_noise_path_input=None, + t5_path_input=None, + clip_path_input=None, + vae_path_input=None, + qwen_image_dit_path_input=None, + qwen_image_vae_path_input=None, + qwen_image_scheduler_path_input=None, + qwen25vl_encoder_path_input=None, + z_image_dit_path_input=None, + z_image_vae_path_input=None, + z_image_scheduler_path_input=None, + qwen3_encoder_path_input=None, + quant_op="triton", + ) + + self.assertEqual(config["model_cls"], "helios") + self.assertEqual(config["model_variant"], "distilled") + self.assertEqual(config["scheduler_type"], "HeliosDMDScheduler") + self.assertEqual(config["model_path"], "/data1/models/BestWishYSH/Helios-Distilled") + self.assertEqual(config["transformer_model_path"], "/data1/models/BestWishYSH/Helios-Distilled/transformer") + self.assertEqual(config["text_encoder_path"], "/data1/models/BestWishYSH/Helios-Distilled/text_encoder") + 
self.assertEqual(config["tokenizer_path"], "/data1/models/BestWishYSH/Helios-Distilled/tokenizer") + self.assertEqual(config["vae_path"], "/data1/models/BestWishYSH/Helios-Distilled/vae") + self.assertEqual(config["scheduler_path"], "/data1/models/BestWishYSH/Helios-Distilled/scheduler") + self.assertTrue(config["is_distilled"]) + + def test_set_config_loads_helios_transformer_and_scheduler_metadata(self): + with tempfile.TemporaryDirectory() as tmpdir: + model_root = os.path.join(tmpdir, "Helios-Distilled") + os.makedirs(os.path.join(model_root, "transformer")) + os.makedirs(os.path.join(model_root, "scheduler")) + os.makedirs(os.path.join(model_root, "text_encoder")) + os.makedirs(os.path.join(model_root, "tokenizer")) + os.makedirs(os.path.join(model_root, "vae")) + + with open(os.path.join(model_root, "configuration.json"), "w", encoding="utf-8") as f: + f.write('{"model_type": "helios"}') + with open(os.path.join(model_root, "model_index.json"), "w", encoding="utf-8") as f: + f.write( + '{"_class_name":"HeliosPyramidPipeline","is_distilled":true,' + '"scheduler":["diffusers","HeliosDMDScheduler"],' + '"transformer":["diffusers","HeliosTransformer3DModel"],' + '"text_encoder":["transformers","UMT5EncoderModel"],' + '"tokenizer":["transformers","T5TokenizerFast"],' + '"vae":["diffusers","AutoencoderKLWan"]}' + ) + with open(os.path.join(model_root, "transformer", "config.json"), "w", encoding="utf-8") as f: + f.write('{"num_layers": 40, "patch_size": [1, 2, 2], "in_channels": 16, "out_channels": 16}') + with open(os.path.join(model_root, "scheduler", "scheduler_config.json"), "w", encoding="utf-8") as f: + f.write('{"_class_name":"HeliosDMDScheduler","stages":3}') + with open(os.path.join(model_root, "vae", "config.json"), "w", encoding="utf-8") as f: + f.write('{"temperal_downsample":[false,true,true]}') + + args = argparse.Namespace( + model_cls="helios", + model_variant="distilled", + task="t2v", + model_path=model_root, + target_video_length=99, + ) + + 
config = set_config(args) + self.assertEqual(config["scheduler_type"], "HeliosDMDScheduler") + self.assertEqual(config["num_layers"], 40) + self.assertEqual(config["patch_size"], [1, 2, 2]) + self.assertEqual(config["vae_scale_factor"], 8) + self.assertTrue(config["is_distilled"]) + + def test_helios_runner_is_native_not_pipeline_bridge(self): + runner_path = os.path.join(REPO_ROOT, "lightx2v", "models", "runners", "helios", "helios_runner.py") + with open(runner_path, "r", encoding="utf-8") as f: + source = f.read() + + self.assertIn("class HeliosRunner", source) + self.assertNotIn("HeliosPyramidPipeline", source) + self.assertNotIn("HeliosPipeline", source) + + def test_infer_cli_exposes_helios_model_cls(self): + infer_path = os.path.join(REPO_ROOT, "lightx2v", "infer.py") + with open(infer_path, "r", encoding="utf-8") as f: + source = f.read() + + self.assertIn('"helios"', source) + + def test_validate_config_paths_has_helios_branch(self): + utils_path = os.path.join(REPO_ROOT, "lightx2v", "utils", "utils.py") + with open(utils_path, "r", encoding="utf-8") as f: + source = f.read() + + self.assertIn('config.get("model_cls") == "helios"', source) + + +if __name__ == "__main__": + unittest.main() From a304320d43705e8d916812362d339d81a0a77412 Mon Sep 17 00:00:00 2001 From: xlycae Date: Fri, 29 May 2026 14:24:04 +0800 Subject: [PATCH 2/6] refactor: tighten helios distilled integration --- app/utils/model_utils.py | 67 ---------------- configs/helios/helios_distilled_i2v.json | 2 +- configs/helios/helios_distilled_t2v.json | 2 +- lightx2v/infer.py | 79 ++++++++++--------- .../models/runners/helios/helios_runner.py | 67 ++++++++++++++-- .../models/runners/helios/runtime_utils.py | 43 ---------- .../models/schedulers/helios/scheduler.py | 5 ++ lightx2v/pipeline.py | 9 +-- lightx2v/utils/set_config.py | 25 +++++- lightx2v/utils/utils.py | 2 +- 10 files changed, 136 insertions(+), 165 deletions(-) delete mode 100644 lightx2v/models/runners/helios/runtime_utils.py diff 
--git a/app/utils/model_utils.py b/app/utils/model_utils.py index f3b8055e5..d085a3732 100644 --- a/app/utils/model_utils.py +++ b/app/utils/model_utils.py @@ -472,68 +472,6 @@ def get_quant_scheme(quant_detected, quant_op_val): return f"{quant_detected}-{quant_op_val}" -def _load_json_if_exists(path): - if path and os.path.exists(path): - with open(path, "r", encoding="utf-8") as f: - return json.load(f) - return None - - -def detect_helios_variant(model_path_input): - model_index = _load_json_if_exists(os.path.join(model_path_input, "model_index.json")) or {} - modular_model_index = _load_json_if_exists(os.path.join(model_path_input, "modular_model_index.json")) or {} - - scheduler_name = "" - scheduler_entry = model_index.get("scheduler") - if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2: - scheduler_name = scheduler_entry[1] - - if not scheduler_name: - scheduler_entry = modular_model_index.get("scheduler") - if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2: - scheduler_name = scheduler_entry[1] - - is_distilled = bool(model_index.get("is_distilled")) or "Distilled" in (modular_model_index.get("_class_name") or "") or scheduler_name == "HeliosDMDScheduler" - variant = "distilled" if is_distilled else "base" - return variant, scheduler_name or ("HeliosDMDScheduler" if is_distilled else "HeliosScheduler"), model_index, modular_model_index - - -def build_helios(model_path_input): - variant, scheduler_type, model_index, modular_model_index = detect_helios_variant(model_path_input) - transformer_config = _load_json_if_exists(os.path.join(model_path_input, "transformer", "config.json")) or {} - scheduler_config = _load_json_if_exists(os.path.join(model_path_input, "scheduler", "scheduler_config.json")) or {} - - helios_config = { - "model_cls": "helios", - "model_variant": variant, - "is_distilled": variant == "distilled", - "model_path": model_path_input, - "transformer_model_path": os.path.join(model_path_input, "transformer"), 
- "text_encoder_path": os.path.join(model_path_input, "text_encoder"), - "tokenizer_path": os.path.join(model_path_input, "tokenizer"), - "vae_path": os.path.join(model_path_input, "vae"), - "scheduler_path": os.path.join(model_path_input, "scheduler"), - "scheduler_type": scheduler_type, - "model_index_class": model_index.get("_class_name") or modular_model_index.get("_class_name"), - "guider_config_path": os.path.join(model_path_input, "guider", "guider_config.json"), - "transformer_ode_model_path": os.path.join(model_path_input, "transformer_ode"), - "history_sizes": [16, 2, 1], - "num_latent_frames_per_chunk": 9, - "use_zero_init": False, - "zero_steps": 1, - "is_enable_stage2": False, - "pyramid_num_inference_steps_list": [20, 20, 20], - "is_skip_first_chunk": False, - "is_amplify_first_chunk": False, - "image_noise_sigma_min": 0.111, - "image_noise_sigma_max": 0.135, - "use_dynamic_shifting": scheduler_config.get("use_dynamic_shifting"), - "use_flow_sigmas": scheduler_config.get("use_flow_sigmas"), - } - helios_config.update(transformer_config) - return helios_config - - def build_wan21( model_path_input, dit_path_input, @@ -1008,8 +946,3 @@ def get_model_configs( if lora_configs: config["lora_configs"] = lora_configs return config - elif model_type_input == "Helios": - config = build_helios(model_path_input) - if lora_configs: - config["lora_configs"] = lora_configs - return config diff --git a/configs/helios/helios_distilled_i2v.json b/configs/helios/helios_distilled_i2v.json index bfa00b210..8dae8c05f 100644 --- a/configs/helios/helios_distilled_i2v.json +++ b/configs/helios/helios_distilled_i2v.json @@ -1,5 +1,5 @@ { - "model_cls": "helios", + "model_cls": "helios_distilled", "model_variant": "distilled", "infer_steps": 6, "target_video_length": 99, diff --git a/configs/helios/helios_distilled_t2v.json b/configs/helios/helios_distilled_t2v.json index b10fec7b3..b63aebe4c 100644 --- a/configs/helios/helios_distilled_t2v.json +++ 
b/configs/helios/helios_distilled_t2v.json @@ -1,5 +1,5 @@ { - "model_cls": "helios", + "model_cls": "helios_distilled", "model_variant": "distilled", "infer_steps": 6, "target_video_length": 99, diff --git a/lightx2v/infer.py b/lightx2v/infer.py index 71adc40e0..6a9998f58 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -40,6 +40,45 @@ from lightx2v.utils.utils import seed_all, validate_config_paths from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER +SUPPORTED_MODEL_CLASSES = [ + "wan2.1", + "wan2.1_distill", + "wan2.1_mean_flow_distill", + "wan2.1_vace", + "wan2.1_sf", + "wan2.1_sf_mtxg2", + "seko_talk", + "seko_talk_ar", + "wan2.2_moe", + "lingbot_world", + "wan2.2", + "wan2.2_matrix_game3", + "wan2.2_moe_audio", + "wan2.2_audio", + "wan2.2_moe_distill", + "wan2.2_moe_vace", + "qwen_image", + "longcat_image", + "wan2.2_animate", + "hunyuan_video_1.5", + "hunyuan_video_1.5_distill", + "helios_distilled", + "hunyuan3d", + "worldplay_distill", + "worldplay_ar", + "worldplay_bi", + "z_image", + "flux2_klein", + "flux2_dev", + "ltx2", + "bagel", + "seedvr2", + "neopp", + "motus", + "lingbot_world_fast", + "worldmirror", +] + def init_runner(config): torch.set_grad_enabled(False) @@ -55,44 +94,6 @@ def main(): "--model_cls", type=str, required=True, - choices=[ - "wan2.1", - "wan2.1_distill", - "wan2.1_mean_flow_distill", - "wan2.1_vace", - "wan2.1_sf", - "wan2.1_sf_mtxg2", - "seko_talk", - "seko_talk_ar", - "wan2.2_moe", - "lingbot_world", - "wan2.2", - "wan2.2_matrix_game3", - "wan2.2_moe_audio", - "wan2.2_audio", - "wan2.2_moe_distill", - "wan2.2_moe_vace", - "qwen_image", - "longcat_image", - "wan2.2_animate", - "hunyuan_video_1.5", - "hunyuan_video_1.5_distill", - "helios", - "hunyuan3d", - "worldplay_distill", - "worldplay_ar", - "worldplay_bi", - "z_image", - "flux2_klein", - "flux2_dev", - "ltx2", - "bagel", - "seedvr2", - "neopp", - "motus", - "lingbot_world_fast", - "worldmirror", - ], default="wan2.1", ) @@ -223,6 +224,8 @@ def 
main(): parser.add_argument("--mux_audio_video_path", type=str, default=None, help="(v2av, optional) After saving, mux audio from this file into the output mp4 (ffmpeg). ") args = parser.parse_args() + if args.model_cls not in SUPPORTED_MODEL_CLASSES: + parser.error(f"invalid --model_cls '{args.model_cls}'. Supported values: {', '.join(SUPPORTED_MODEL_CLASSES)}") # validate_task_arguments(args) seed_all(args.seed) diff --git a/lightx2v/models/runners/helios/helios_runner.py b/lightx2v/models/runners/helios/helios_runner.py index 14504182d..ad25bed02 100644 --- a/lightx2v/models/runners/helios/helios_runner.py +++ b/lightx2v/models/runners/helios/helios_runner.py @@ -9,7 +9,6 @@ from lightx2v.models.input_encoders.hf.helios import HeliosTextEncoder from lightx2v.models.networks.helios import HeliosModel from lightx2v.models.runners.default_runner import DefaultRunner -from lightx2v.models.runners.helios.runtime_utils import apply_image_condition_noise, finalize_video_output, pt_video_output_to_comfy_frames from lightx2v.models.schedulers.helios import HeliosDistilledScheduler from lightx2v.models.video_encoders.hf.helios import HeliosVAE from lightx2v.server.metrics import monitor_cli @@ -36,13 +35,63 @@ def randn_tensor(shape, generator=None, device=None, dtype=None): return torch.randn(shape, generator=generator, device=device, dtype=dtype) -@RUNNER_REGISTER("helios") +def _apply_image_condition_noise( + image_latents, + fake_image_latents, + generator, + device, + image_noise_sigma_min, + image_noise_sigma_max, + video_noise_sigma_min, + video_noise_sigma_max, +): + image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + image_noise_sigma_min + ) + image_latents = ( + image_noise_sigma * torch.randn(image_latents.shape, generator=generator, device=device) + (1 - image_noise_sigma) * image_latents + ) + fake_image_noise_sigma = ( + torch.rand(1, device=device, generator=generator) * 
(video_noise_sigma_max - video_noise_sigma_min) + video_noise_sigma_min + ) + fake_image_latents = ( + fake_image_noise_sigma * torch.randn(fake_image_latents.shape, generator=generator, device=device) + + (1 - fake_image_noise_sigma) * fake_image_latents + ) + return image_latents, fake_image_latents + + +def _trim_generated_frames(frame_count, temporal_scale_factor): + return ((frame_count - 1) // temporal_scale_factor) * temporal_scale_factor + 1 + + +def _finalize_video_output(history_video, video_processor, temporal_scale_factor, output_type="pt"): + generated_frames = _trim_generated_frames(history_video.size(2), temporal_scale_factor) + history_video = history_video[:, :, :generated_frames] + return video_processor.postprocess_video(history_video, output_type=output_type) + + +def _pt_video_output_to_frames(video): + if video.dim() != 5: + raise ValueError(f"Expected [B, T, C, H, W] tensor, got shape {tuple(video.shape)}") + return video.permute(0, 1, 3, 4, 2).flatten(0, 1).cpu() + + +@RUNNER_REGISTER("helios_distilled") class HeliosRunner(DefaultRunner): def __init__(self, config): super().__init__(config) self.keep_first_frame = self.config.get("keep_first_frame", True) self.request_generator = None + def set_inputs(self, inputs): + self.request_generator = None + super().set_inputs(inputs) + + def end_run(self): + self.request_generator = None + super().end_run() + def get_request_generator(self): if self.request_generator is None: self.request_generator = torch.Generator(device=AI_DEVICE).manual_seed(self.input_info.seed) @@ -71,15 +120,21 @@ def load_vae(self): return vae, vae def init_modules(self): - super().init_modules() if self.config["task"] not in ["t2v", "i2v"]: raise NotImplementedError(f"HeliosRunner only supports t2v/i2v, got {self.config['task']}") + if self.config.get("lazy_load"): + raise NotImplementedError("Helios native integration does not support lazy_load.") + if self.config.get("unload_modules"): + raise NotImplementedError("Helios 
native integration does not support unload_modules.") + if self.config.get("cpu_offload"): + raise NotImplementedError("Helios native integration does not support generic cpu_offload.") if self.config.get("compile"): raise NotImplementedError("Helios native integration does not support compile yet.") if self.config.get("enable_low_vram_mode"): raise NotImplementedError("Helios native integration does not support group offload yet.") if self.config.get("enable_parallelism"): raise NotImplementedError("Helios native integration does not support context parallelism yet.") + super().init_modules() def get_latent_shape_with_target_hw(self): target_height = self.input_info.target_shape[0] if self.input_info.target_shape and len(self.input_info.target_shape) == 2 else self.config["target_height"] @@ -137,7 +192,7 @@ def _run_input_encoder_local_i2v(self): width=self.config["target_width"], dtype=torch.float32, ) - image_latents, fake_image_latents = apply_image_condition_noise( + image_latents, fake_image_latents = _apply_image_condition_noise( image_latents=image_latents, fake_image_latents=fake_image_latents, generator=generator, @@ -369,8 +424,8 @@ def run_main(self): history_video = current_video if history_video is None else torch.cat([history_video, current_video], dim=2) self.gen_video = history_video - self.gen_video_final = pt_video_output_to_comfy_frames( - finalize_video_output( + self.gen_video_final = _pt_video_output_to_frames( + _finalize_video_output( history_video=self.gen_video, video_processor=self.vae_decoder.video_processor, temporal_scale_factor=self.vae_decoder.vae_scale_factor_temporal, diff --git a/lightx2v/models/runners/helios/runtime_utils.py b/lightx2v/models/runners/helios/runtime_utils.py deleted file mode 100644 index b8776045f..000000000 --- a/lightx2v/models/runners/helios/runtime_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch - - -def apply_image_condition_noise( - image_latents, - fake_image_latents, - generator, - device, - 
image_noise_sigma_min, - image_noise_sigma_max, - video_noise_sigma_min, - video_noise_sigma_max, -): - image_noise_sigma = ( - torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + image_noise_sigma_min - ) - image_latents = ( - image_noise_sigma * torch.randn(image_latents.shape, generator=generator, device=device) + (1 - image_noise_sigma) * image_latents - ) - fake_image_noise_sigma = ( - torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + video_noise_sigma_min - ) - fake_image_latents = ( - fake_image_noise_sigma * torch.randn(fake_image_latents.shape, generator=generator, device=device) - + (1 - fake_image_noise_sigma) * fake_image_latents - ) - return image_latents, fake_image_latents - - -def trim_generated_frames(frame_count, temporal_scale_factor): - return ((frame_count - 1) // temporal_scale_factor) * temporal_scale_factor + 1 - - -def finalize_video_output(history_video, video_processor, temporal_scale_factor, output_type="pt"): - generated_frames = trim_generated_frames(history_video.size(2), temporal_scale_factor) - history_video = history_video[:, :, :generated_frames] - return video_processor.postprocess_video(history_video, output_type=output_type) - - -def pt_video_output_to_comfy_frames(video): - if video.dim() != 5: - raise ValueError(f"Expected [B, T, C, H, W] tensor, got shape {tuple(video.shape)}") - return video.permute(0, 1, 3, 4, 2).flatten(0, 1).cpu() diff --git a/lightx2v/models/schedulers/helios/scheduler.py b/lightx2v/models/schedulers/helios/scheduler.py index 18355cbc8..7ebdee34f 100644 --- a/lightx2v/models/schedulers/helios/scheduler.py +++ b/lightx2v/models/schedulers/helios/scheduler.py @@ -14,6 +14,8 @@ def __init__(self, config): self.sample_guide_scale = config.get("sample_guide_scale", 1.0) self.generator = None self.latents = None + self.latent_shape = None + self.image_encoder_output = None self.timesteps = None self.sigmas = 
None @@ -42,6 +44,9 @@ def step_pre(self, step_index): self.step_index = step_index def clear(self): + self.generator = None self.latents = None + self.latent_shape = None + self.image_encoder_output = None self.timesteps = None self.sigmas = None diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index 247b690f9..9bbaf360b 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -91,7 +91,6 @@ def __init__( self.low_noise_original_ckpt = low_noise_original_ckpt self.high_noise_original_ckpt = high_noise_original_ckpt self.transformer_model_name = transformer_model_name - if self.model_cls in [ "wan2.1", "wan2.1_distill", @@ -117,7 +116,7 @@ def __init__( elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: self.vae_stride = (4, 16, 16) self.num_channels_latents = 32 - elif self.model_cls in ["helios"]: + elif self.model_cls in ["helios_distilled"]: self.vae_stride = (4, 8, 8) self.num_channels_latents = 16 elif self.model_cls in ["ltx2"]: @@ -252,7 +251,7 @@ def set_infer_config( self.self_attn_1_type = attn_mode self.cross_attn_1_type = attn_mode self.cross_attn_2_type = attn_mode - elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image", "longcat_image", "ltx2", "z_image", "helios"]: + elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image", "longcat_image", "ltx2", "z_image", "helios_distilled"]: self.attn_type = attn_mode self.norm_modulate_backend = norm_modulate_backend @@ -308,7 +307,7 @@ def enable_quantize( self.qwen25vl_quantized = text_encoder_quantized self.qwen25vl_quantized_ckpt = text_encoder_quantized_ckpt self.qwen25vl_quant_scheme = text_encoder_quant_scheme - elif self.model_cls == "helios": + elif self.model_cls == "helios_distilled": self.text_encoder_quantized = text_encoder_quantized self.text_encoder_quantized_ckpt = text_encoder_quantized_ckpt self.text_encoder_quant_scheme = text_encoder_quant_scheme @@ -353,7 +352,7 @@ def enable_offload( 
self.qwen25vl_cpu_offload = text_encoder_offload self.siglip_cpu_offload = image_encoder_offload self.byt5_cpu_offload = image_encoder_offload - elif self.model_cls == "helios": + elif self.model_cls == "helios_distilled": self.text_encoder_cpu_offload = text_encoder_offload elif self.model_cls in ["qwen_image", "longcat_image"]: self.qwen25vl_cpu_offload = text_encoder_offload diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index f90faa1d2..6c559c05d 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -121,7 +121,7 @@ def auto_calc_config(config): config["config_path"] = wm_config if wm_ckpt: config["ckpt_path"] = wm_ckpt - elif config["model_cls"] == "helios": + elif config["model_cls"] == "helios_distilled": transformer_path = os.path.join(config["model_path"], "transformer") config["transformer_model_path"] = transformer_path config["text_encoder_path"] = config.get("text_encoder_path", os.path.join(config["model_path"], "text_encoder")) @@ -131,14 +131,21 @@ def auto_calc_config(config): config["max_sequence_length"] = config.get("max_sequence_length", 512) model_index_path = os.path.join(config["model_path"], "model_index.json") + modular_model_index_path = os.path.join(config["model_path"], "modular_model_index.json") + model_index = {} + modular_model_index = {} if os.path.exists(model_index_path): with open(model_index_path, "r", encoding="utf-8") as f: model_index = json.load(f) - config["is_distilled"] = bool(model_index.get("is_distilled", config.get("is_distilled", False))) - config["model_variant"] = config.get("model_variant", "distilled" if config["is_distilled"] else "base") scheduler_entry = model_index.get("scheduler") if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2: config["scheduler_type"] = scheduler_entry[1] + if os.path.exists(modular_model_index_path): + with open(modular_model_index_path, "r", encoding="utf-8") as f: + modular_model_index = json.load(f) + 
scheduler_entry = modular_model_index.get("scheduler") + if isinstance(scheduler_entry, list) and len(scheduler_entry) >= 2 and not config.get("scheduler_type"): + config["scheduler_type"] = scheduler_entry[1] if os.path.exists(os.path.join(transformer_path, "config.json")): with open(os.path.join(transformer_path, "config.json"), "r", encoding="utf-8") as f: @@ -153,6 +160,18 @@ def auto_calc_config(config): for key in ("stages", "stage_range", "shift", "prediction_type", "time_shift_type", "use_dynamic_shifting", "use_flow_sigmas"): if key in scheduler_config: config[key] = scheduler_config[key] + + is_distilled = bool(model_index.get("is_distilled", config.get("is_distilled", False))) or "Distilled" in (modular_model_index.get("_class_name") or "") or config.get("scheduler_type") == "HeliosDMDScheduler" + config["is_distilled"] = is_distilled + if not is_distilled: + scheduler_hint = config.get("scheduler_type", "unknown") + raise ValueError( + f"Unsupported Helios checkpoint at {config['model_path']}: " + f"LightX2V only supports Helios-Distilled checkpoints, but detected base/unsupported metadata " + f"(scheduler={scheduler_hint})." 
+ ) + config["model_cls"] = "helios_distilled" + config["model_variant"] = "distilled" elif config["model_cls"] == "longcat_image": # Special config for longcat_image: load both root and transformer config if os.path.exists(os.path.join(config["model_path"], "config.json")): with open(os.path.join(config["model_path"], "config.json"), "r") as f: diff --git a/lightx2v/utils/utils.py b/lightx2v/utils/utils.py index 5e0826700..fca2d1344 100755 --- a/lightx2v/utils/utils.py +++ b/lightx2v/utils/utils.py @@ -736,7 +736,7 @@ def validate_config_paths(config: dict) -> None: check_path_exists(config["dit_original_ckpt"]) logger.debug(f"✓ Verified dit_original_ckpt: {config['dit_original_ckpt']}") - if config.get("model_cls") == "helios": + if config.get("model_cls") == "helios_distilled": for key in ("transformer_model_path", "text_encoder_path", "tokenizer_path", "vae_path", "scheduler_path"): if key in config and config[key] is not None: check_path_exists(config[key]) From 0eaccf7b12a01fd546e14c1c47a22c3597031ce7 Mon Sep 17 00:00:00 2001 From: xlycae Date: Fri, 29 May 2026 14:33:04 +0800 Subject: [PATCH 3/6] chore: drop local helios test scaffolding --- app/utils/model_utils.py | 1 - test_cases/test_helios_consistency_helpers.py | 195 ----------------- test_cases/test_helios_distilled_support.py | 198 ------------------ 3 files changed, 394 deletions(-) delete mode 100644 test_cases/test_helios_consistency_helpers.py delete mode 100644 test_cases/test_helios_distilled_support.py diff --git a/app/utils/model_utils.py b/app/utils/model_utils.py index d085a3732..c40181d86 100644 --- a/app/utils/model_utils.py +++ b/app/utils/model_utils.py @@ -22,7 +22,6 @@ MS_AVAILABLE = False import gc import importlib.util -import json import re import psutil diff --git a/test_cases/test_helios_consistency_helpers.py b/test_cases/test_helios_consistency_helpers.py deleted file mode 100644 index 7a8f2669b..000000000 --- a/test_cases/test_helios_consistency_helpers.py +++ /dev/null @@ 
-1,195 +0,0 @@ -import importlib.util -import sys -import types -import unittest -from pathlib import Path - -import torch - - -REPO_ROOT = Path(__file__).resolve().parents[1] - - -def load_module(module_name: str, file_path: Path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - module = importlib.util.module_from_spec(spec) - assert spec.loader is not None - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - -class HeliosPromptPackingTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - transformers_module = types.ModuleType("transformers") - transformers_module.AutoTokenizer = object - transformers_module.UMT5EncoderModel = object - sys.modules.setdefault("transformers", transformers_module) - - envs_module = types.ModuleType("lightx2v.utils.envs") - envs_module.GET_DTYPE = lambda: torch.bfloat16 - sys.modules.setdefault("lightx2v.utils.envs", envs_module) - - global_var_module = types.ModuleType("lightx2v_platform.base.global_var") - global_var_module.AI_DEVICE = "cpu" - sys.modules.setdefault("lightx2v_platform.base.global_var", global_var_module) - - cls.text_module = load_module( - "test_helios_text_model", - REPO_ROOT / "lightx2v/models/input_encoders/hf/helios/model.py", - ) - - def test_pack_prompt_embeds_reapplies_sequence_lengths_before_padding(self): - hidden_state = torch.tensor( - [ - [[1.0, 10.0], [2.0, 20.0], [999.0, 999.0], [999.0, 999.0]], - [[3.0, 30.0], [4.0, 40.0], [5.0, 50.0], [999.0, 999.0]], - ] - ) - attention_mask = torch.tensor( - [ - [1, 1, 0, 0], - [1, 1, 1, 0], - ] - ) - - prompt_embeds, mask = self.text_module.pack_t5_prompt_embeds( - hidden_state, - attention_mask, - max_sequence_length=4, - num_videos_per_prompt=2, - dtype=torch.bfloat16, - device=torch.device("cpu"), - ) - - self.assertEqual(tuple(prompt_embeds.shape), (4, 4, 2)) - self.assertEqual(prompt_embeds.dtype, torch.bfloat16) - self.assertTrue(mask.dtype == torch.bool) - 
self.assertTrue(torch.equal(mask, attention_mask.bool())) - self.assertTrue(torch.equal(prompt_embeds[0, 2:], torch.zeros((2, 2), dtype=torch.bfloat16))) - self.assertTrue(torch.equal(prompt_embeds[1], prompt_embeds[0])) - self.assertTrue(torch.equal(prompt_embeds[2, 3:], torch.zeros((1, 2), dtype=torch.bfloat16))) - self.assertEqual(prompt_embeds[2, 2, 0].item(), 5.0) - - -class HeliosRuntimeUtilsTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.runtime_utils = load_module( - "test_helios_runtime_utils", - REPO_ROOT / "lightx2v/models/runners/helios/runtime_utils.py", - ) - - def test_apply_image_condition_noise_uses_distinct_sigmas_and_generator_order(self): - image_latents = torch.ones((1, 1, 1, 1, 1), dtype=torch.float32) - fake_image_latents = torch.full((1, 1, 1, 1, 1), 2.0, dtype=torch.float32) - generator = torch.Generator(device="cpu").manual_seed(123) - - noisy_image, noisy_fake = self.runtime_utils.apply_image_condition_noise( - image_latents=image_latents, - fake_image_latents=fake_image_latents, - generator=generator, - device=torch.device("cpu"), - image_noise_sigma_min=0.111, - image_noise_sigma_max=0.135, - video_noise_sigma_min=0.211, - video_noise_sigma_max=0.235, - ) - - ref_generator = torch.Generator(device="cpu").manual_seed(123) - image_sigma = torch.rand(1, device="cpu", generator=ref_generator) * (0.135 - 0.111) + 0.111 - ref_noisy_image = image_sigma * torch.randn(image_latents.shape, generator=ref_generator) + (1 - image_sigma) * image_latents - fake_sigma = torch.rand(1, device="cpu", generator=ref_generator) * (0.235 - 0.211) + 0.211 - ref_noisy_fake = fake_sigma * torch.randn(fake_image_latents.shape, generator=ref_generator) + (1 - fake_sigma) * fake_image_latents - - self.assertTrue(torch.allclose(noisy_image, ref_noisy_image)) - self.assertTrue(torch.allclose(noisy_fake, ref_noisy_fake)) - - def test_trim_and_postprocess_video_matches_helios_frame_rule(self): - history_video = torch.arange(1 * 3 * 99 * 2 * 2, 
dtype=torch.float32).reshape(1, 3, 99, 2, 2) - - class DummyVideoProcessor: - def __init__(self): - self.called = False - self.last_shape = None - - def postprocess_video(self, video, output_type="np"): - self.called = True - self.last_shape = tuple(video.shape) - return {"frames": video.clone(), "output_type": output_type} - - processor = DummyVideoProcessor() - result = self.runtime_utils.finalize_video_output( - history_video=history_video, - video_processor=processor, - temporal_scale_factor=4, - output_type="np", - ) - - self.assertTrue(processor.called) - self.assertEqual(processor.last_shape, (1, 3, 97, 2, 2)) - self.assertEqual(tuple(result["frames"].shape), (1, 3, 97, 2, 2)) - self.assertEqual(result["output_type"], "np") - - def test_pt_video_output_is_converted_to_comfy_frame_layout(self): - pt_video = torch.arange(1 * 2 * 3 * 2 * 2, dtype=torch.float32).reshape(1, 2, 3, 2, 2) - frames = self.runtime_utils.pt_video_output_to_comfy_frames(pt_video) - self.assertEqual(tuple(frames.shape), (2, 2, 2, 3)) - self.assertTrue(torch.equal(frames[0, 0, 0], torch.tensor([0.0, 4.0, 8.0]))) - - -class HeliosI2VGeneratorContinuityTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - base_scheduler_module = types.ModuleType("lightx2v.models.schedulers.scheduler") - - class BaseScheduler: - def __init__(self, config): - self.config = config - - base_scheduler_module.BaseScheduler = BaseScheduler - sys.modules.setdefault("lightx2v.models.schedulers.scheduler", base_scheduler_module) - - fake_dmd_module = types.ModuleType("lightx2v.models.schedulers.helios.helios_dmd") - - class FakeInnerScheduler: - config = types.SimpleNamespace() - - @classmethod - def from_pretrained(cls, _path): - return cls() - - fake_dmd_module.HeliosDMDScheduler = FakeInnerScheduler - sys.modules.setdefault("lightx2v.models.schedulers.helios.helios_dmd", fake_dmd_module) - - global_var_module = types.ModuleType("lightx2v_platform.base.global_var") - global_var_module.AI_DEVICE = "cpu" - 
sys.modules["lightx2v_platform.base.global_var"] = global_var_module - - cls.scheduler_module = load_module( - "test_helios_scheduler_module", - REPO_ROOT / "lightx2v/models/schedulers/helios/scheduler.py", - ) - - def test_prepare_reuses_external_generator_for_i2v_rng_continuity(self): - scheduler = self.scheduler_module.HeliosDistilledScheduler( - { - "scheduler_path": "/tmp/unused", - "pyramid_num_inference_steps_list": [2, 2, 2], - "sample_guide_scale": 1.0, - } - ) - external_generator = torch.Generator(device="cpu").manual_seed(42) - scheduler.prepare(seed=999, latent_shape=[16, 25, 48, 80], image_encoder_output={}, generator=external_generator) - self.assertIs(scheduler.generator, external_generator) - - def test_helios_runner_i2v_no_longer_reseeds_a_second_generator(self): - runner_path = REPO_ROOT / "lightx2v/models/runners/helios/helios_runner.py" - source = runner_path.read_text(encoding="utf-8") - i2v_block = source.split("def _run_input_encoder_local_i2v", 1)[1].split("def sample_block_noise", 1)[0] - self.assertNotIn("manual_seed(self.input_info.seed)", i2v_block) - - -if __name__ == "__main__": - unittest.main() diff --git a/test_cases/test_helios_distilled_support.py b/test_cases/test_helios_distilled_support.py deleted file mode 100644 index 5af5c1815..000000000 --- a/test_cases/test_helios_distilled_support.py +++ /dev/null @@ -1,198 +0,0 @@ -import argparse -import importlib.machinery -import os -import sys -import tempfile -import types -import unittest - - -REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -APP_ROOT = os.path.join(REPO_ROOT, "app") - -if REPO_ROOT not in sys.path: - sys.path.insert(0, REPO_ROOT) -if APP_ROOT not in sys.path: - sys.path.insert(0, APP_ROOT) - -lightx2v_pkg = types.ModuleType("lightx2v") -lightx2v_pkg.__path__ = [os.path.join(REPO_ROOT, "lightx2v")] -sys.modules.setdefault("lightx2v", lightx2v_pkg) - -lightx2v_platform_pkg = types.ModuleType("lightx2v_platform") 
-lightx2v_platform_pkg.__path__ = [os.path.join(REPO_ROOT, "lightx2v_platform")] -sys.modules.setdefault("lightx2v_platform", lightx2v_platform_pkg) - -lightx2v_platform_base_pkg = types.ModuleType("lightx2v_platform.base") -lightx2v_platform_base_pkg.__path__ = [os.path.join(REPO_ROOT, "lightx2v_platform", "base")] -sys.modules.setdefault("lightx2v_platform.base", lightx2v_platform_base_pkg) - -global_var_module = types.ModuleType("lightx2v_platform.base.global_var") -global_var_module.AI_DEVICE = "cuda" -sys.modules.setdefault("lightx2v_platform.base.global_var", global_var_module) - -lightx2v_utils_module = types.ModuleType("lightx2v.utils.utils") -lightx2v_utils_module.is_main_process = lambda: True -sys.modules.setdefault("lightx2v.utils.utils", lightx2v_utils_module) - -if "loguru" not in sys.modules: - sys.modules["loguru"] = types.SimpleNamespace( - logger=types.SimpleNamespace(info=lambda *a, **k: None, warning=lambda *a, **k: None, debug=lambda *a, **k: None) - ) -if "psutil" not in sys.modules: - sys.modules["psutil"] = types.SimpleNamespace(virtual_memory=lambda: types.SimpleNamespace(available=0)) -if "huggingface_hub" not in sys.modules: - hf_module = types.ModuleType("huggingface_hub") - hf_module.HfApi = object - hf_module.list_repo_files = lambda *a, **k: [] - hf_module.__spec__ = importlib.machinery.ModuleSpec("huggingface_hub", loader=None) - sys.modules["huggingface_hub"] = hf_module -if "modelscope" not in sys.modules: - modelscope_module = types.ModuleType("modelscope") - modelscope_module.__spec__ = importlib.machinery.ModuleSpec("modelscope", loader=None) - hub_module = types.ModuleType("modelscope.hub") - hub_module.__spec__ = importlib.machinery.ModuleSpec("modelscope.hub", loader=None) - api_module = types.ModuleType("modelscope.hub.api") - api_module.__spec__ = importlib.machinery.ModuleSpec("modelscope.hub.api", loader=None) - api_module.HubApi = object - sys.modules["modelscope"] = modelscope_module - sys.modules["modelscope.hub"] = 
hub_module - sys.modules["modelscope.hub.api"] = api_module -if "torch" not in sys.modules: - torch_module = types.ModuleType("torch") - torch_module.__spec__ = importlib.machinery.ModuleSpec("torch", loader=None) - torch_module.float16 = "float16" - torch_module.float32 = "float32" - torch_module.bfloat16 = "bfloat16" - torch_module.Tensor = object - torch_module._scaled_mm = object() - torch_module.cuda = types.SimpleNamespace( - is_available=lambda: False, - get_device_capability=lambda *_: (0, 0), - get_device_name=lambda *_: "", - empty_cache=lambda: None, - synchronize=lambda: None, - ) - torch_module.device = lambda value: value - - dist_module = types.ModuleType("torch.distributed") - dist_module.is_initialized = lambda: False - dist_module.get_rank = lambda: 0 - dist_module.get_world_size = lambda: 1 - dist_module.all_reduce = lambda *_args, **_kwargs: None - - tensor_module = types.ModuleType("torch.distributed.tensor") - device_mesh_module = types.ModuleType("torch.distributed.tensor.device_mesh") - device_mesh_module.init_device_mesh = lambda *_args, **_kwargs: None - - torch_module.distributed = dist_module - sys.modules["torch"] = torch_module - sys.modules["torch.distributed"] = dist_module - sys.modules["torch.distributed.tensor"] = tensor_module - sys.modules["torch.distributed.tensor.device_mesh"] = device_mesh_module - -from utils.model_utils import get_model_configs -from lightx2v.utils.set_config import set_config - - -class HeliosDistilledSupportTest(unittest.TestCase): - def test_get_model_configs_detects_helios_distilled_variant(self): - config = get_model_configs( - model_type_input="Helios", - model_path_input="/data1/models/BestWishYSH/Helios-Distilled", - dit_path_input=None, - high_noise_path_input=None, - low_noise_path_input=None, - t5_path_input=None, - clip_path_input=None, - vae_path_input=None, - qwen_image_dit_path_input=None, - qwen_image_vae_path_input=None, - qwen_image_scheduler_path_input=None, - 
qwen25vl_encoder_path_input=None, - z_image_dit_path_input=None, - z_image_vae_path_input=None, - z_image_scheduler_path_input=None, - qwen3_encoder_path_input=None, - quant_op="triton", - ) - - self.assertEqual(config["model_cls"], "helios") - self.assertEqual(config["model_variant"], "distilled") - self.assertEqual(config["scheduler_type"], "HeliosDMDScheduler") - self.assertEqual(config["model_path"], "/data1/models/BestWishYSH/Helios-Distilled") - self.assertEqual(config["transformer_model_path"], "/data1/models/BestWishYSH/Helios-Distilled/transformer") - self.assertEqual(config["text_encoder_path"], "/data1/models/BestWishYSH/Helios-Distilled/text_encoder") - self.assertEqual(config["tokenizer_path"], "/data1/models/BestWishYSH/Helios-Distilled/tokenizer") - self.assertEqual(config["vae_path"], "/data1/models/BestWishYSH/Helios-Distilled/vae") - self.assertEqual(config["scheduler_path"], "/data1/models/BestWishYSH/Helios-Distilled/scheduler") - self.assertTrue(config["is_distilled"]) - - def test_set_config_loads_helios_transformer_and_scheduler_metadata(self): - with tempfile.TemporaryDirectory() as tmpdir: - model_root = os.path.join(tmpdir, "Helios-Distilled") - os.makedirs(os.path.join(model_root, "transformer")) - os.makedirs(os.path.join(model_root, "scheduler")) - os.makedirs(os.path.join(model_root, "text_encoder")) - os.makedirs(os.path.join(model_root, "tokenizer")) - os.makedirs(os.path.join(model_root, "vae")) - - with open(os.path.join(model_root, "configuration.json"), "w", encoding="utf-8") as f: - f.write('{"model_type": "helios"}') - with open(os.path.join(model_root, "model_index.json"), "w", encoding="utf-8") as f: - f.write( - '{"_class_name":"HeliosPyramidPipeline","is_distilled":true,' - '"scheduler":["diffusers","HeliosDMDScheduler"],' - '"transformer":["diffusers","HeliosTransformer3DModel"],' - '"text_encoder":["transformers","UMT5EncoderModel"],' - '"tokenizer":["transformers","T5TokenizerFast"],' - 
'"vae":["diffusers","AutoencoderKLWan"]}' - ) - with open(os.path.join(model_root, "transformer", "config.json"), "w", encoding="utf-8") as f: - f.write('{"num_layers": 40, "patch_size": [1, 2, 2], "in_channels": 16, "out_channels": 16}') - with open(os.path.join(model_root, "scheduler", "scheduler_config.json"), "w", encoding="utf-8") as f: - f.write('{"_class_name":"HeliosDMDScheduler","stages":3}') - with open(os.path.join(model_root, "vae", "config.json"), "w", encoding="utf-8") as f: - f.write('{"temperal_downsample":[false,true,true]}') - - args = argparse.Namespace( - model_cls="helios", - model_variant="distilled", - task="t2v", - model_path=model_root, - target_video_length=99, - ) - - config = set_config(args) - self.assertEqual(config["scheduler_type"], "HeliosDMDScheduler") - self.assertEqual(config["num_layers"], 40) - self.assertEqual(config["patch_size"], [1, 2, 2]) - self.assertEqual(config["vae_scale_factor"], 8) - self.assertTrue(config["is_distilled"]) - - def test_helios_runner_is_native_not_pipeline_bridge(self): - runner_path = os.path.join(REPO_ROOT, "lightx2v", "models", "runners", "helios", "helios_runner.py") - with open(runner_path, "r", encoding="utf-8") as f: - source = f.read() - - self.assertIn("class HeliosRunner", source) - self.assertNotIn("HeliosPyramidPipeline", source) - self.assertNotIn("HeliosPipeline", source) - - def test_infer_cli_exposes_helios_model_cls(self): - infer_path = os.path.join(REPO_ROOT, "lightx2v", "infer.py") - with open(infer_path, "r", encoding="utf-8") as f: - source = f.read() - - self.assertIn('"helios"', source) - - def test_validate_config_paths_has_helios_branch(self): - utils_path = os.path.join(REPO_ROOT, "lightx2v", "utils", "utils.py") - with open(utils_path, "r", encoding="utf-8") as f: - source = f.read() - - self.assertIn('config.get("model_cls") == "helios"', source) - - -if __name__ == "__main__": - unittest.main() From ef9c7754b6b3f6d9c782009216e69a18a0a8c033 Mon Sep 17 00:00:00 2001 From: 
xlycae Date: Fri, 29 May 2026 15:53:22 +0800 Subject: [PATCH 4/6] fix: address helios review issues --- lightx2v/infer.py | 2 +- lightx2v/models/networks/helios/model.py | 1 - .../networks/helios/transformer_helios.py | 52 +++++-------------- .../models/runners/helios/helios_runner.py | 37 ++++++------- .../models/schedulers/helios/helios_dmd.py | 13 ++--- .../models/video_encoders/hf/helios/vae.py | 5 +- lightx2v/utils/set_config.py | 10 ++-- scripts/helios/run_helios_distilled_i2v.sh | 2 +- scripts/helios/run_helios_distilled_t2v.sh | 2 +- 9 files changed, 43 insertions(+), 81 deletions(-) diff --git a/lightx2v/infer.py b/lightx2v/infer.py index 6a9998f58..5c0bfab4e 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -7,12 +7,12 @@ from lightx2v.common.ops import * from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401 +from lightx2v.models.runners.helios.helios_runner import HeliosRunner # noqa: F401 from lightx2v.models.runners.hunyuan3d.hunyuan3d_shape_runner import Hunyuan3DShapeRunner # noqa: F401 # from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 -from lightx2v.models.runners.helios.helios_runner import HeliosRunner # noqa: F401 from lightx2v.models.runners.longcat_image.longcat_image_runner import LongCatImageRunner # noqa: F401 from lightx2v.models.runners.ltx2.ltx2_runner import LTX2Runner # noqa: F401 from lightx2v.models.runners.motus.motus_runner import MotusRunner # noqa: F401 diff --git a/lightx2v/models/networks/helios/model.py b/lightx2v/models/networks/helios/model.py index ad62b6f44..6dbc4a5de 100644 --- a/lightx2v/models/networks/helios/model.py +++ b/lightx2v/models/networks/helios/model.py @@ -5,7 +5,6 @@ from 
lightx2v.models.networks.helios.transformer_helios import HeliosTransformer3DModel from lightx2v.utils.envs import GET_DTYPE -from lightx2v_platform.base.global_var import AI_DEVICE class HeliosModel: diff --git a/lightx2v/models/networks/helios/transformer_helios.py b/lightx2v/models/networks/helios/transformer_helios.py index f311be0af..63313325b 100644 --- a/lightx2v/models/networks/helios/transformer_helios.py +++ b/lightx2v/models/networks/helios/transformer_helios.py @@ -18,7 +18,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput @@ -32,7 +31,6 @@ from diffusers.utils import apply_lora_scale, logging from diffusers.utils.torch_utils import maybe_allow_in_graph - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -102,9 +100,7 @@ class HeliosAttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher." - ) + raise ImportError("HeliosAttnProcessor requires PyTorch 2.0. 
To use it, please upgrade PyTorch to version 2.0 or higher.") def __call__( self, @@ -228,18 +224,14 @@ def fuse_projections(self): out_features, in_features = concatenated_weights.shape with torch.device("meta"): self.to_qkv = nn.Linear(in_features, out_features, bias=True) - self.to_qkv.load_state_dict( - {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True - ) + self.to_qkv.load_state_dict({"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True) else: concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data]) concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data]) out_features, in_features = concatenated_weights.shape with torch.device("meta"): self.to_kv = nn.Linear(in_features, out_features, bias=True) - self.to_kv.load_state_dict( - {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True - ) + self.to_kv.load_state_dict({"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True) if self.added_kv_proj_dim is not None: concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data]) @@ -247,9 +239,7 @@ def fuse_projections(self): out_features, in_features = concatenated_weights.shape with torch.device("meta"): self.to_added_kv = nn.Linear(in_features, out_features, bias=True) - self.to_added_kv.load_state_dict( - {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True - ) + self.to_added_kv.load_state_dict({"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True) self.fused_projections = True @@ -430,9 +420,7 @@ def forward( original_context_length: int = None, ) -> torch.Tensor: if temb.ndim == 4: - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table.unsqueeze(0) + temb.float() - ).chunk(6, dim=2) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = 
(self.scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2) # batch_size, seq_len, 1, inner_dim shift_msa = shift_msa.squeeze(2) scale_msa = scale_msa.squeeze(2) @@ -441,9 +429,7 @@ def forward( c_scale_msa = c_scale_msa.squeeze(2) c_gate_msa = c_gate_msa.squeeze(2) else: - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( - self.scale_shift_table + temb.float() - ).chunk(6, dim=1) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb.float()).chunk(6, dim=1) # 1. Self-attention norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) @@ -460,9 +446,7 @@ def forward( if self.guidance_cross_attn: history_seq_len = hidden_states.shape[1] - original_context_length - history_hidden_states, hidden_states = torch.split( - hidden_states, [history_seq_len, original_context_length], dim=1 - ) + history_hidden_states, hidden_states = torch.split(hidden_states, [history_seq_len, original_context_length], dim=1) norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) attn_output = self.attn2( norm_hidden_states, @@ -485,18 +469,14 @@ def forward( hidden_states = hidden_states + attn_output # 3. Feed-forward - norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( - hidden_states - ) + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) ff_output = self.ffn(norm_hidden_states) hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) return hidden_states -class HeliosTransformer3DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin -): +class HeliosTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): r""" A Transformer model for video-like data used in the Helios model. 
@@ -751,15 +731,9 @@ def forward( if indices_hidden_states is not None and self.zero_history_timestep: timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device) - temb_t0, timestep_proj_t0, _ = self.condition_embedder( - timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False - ) + temb_t0, timestep_proj_t0, _ = self.condition_embedder(timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False) temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1) - timestep_proj_t0 = ( - timestep_proj_t0.unflatten(-1, (6, -1)) - .view(1, 6, 1, -1) - .expand(batch_size, -1, history_context_length, -1) - ) + timestep_proj_t0 = timestep_proj_t0.unflatten(-1, (6, -1)).view(1, 6, 1, -1).expand(batch_size, -1, history_context_length, -1) temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states) timestep_proj = timestep_proj.unflatten(-1, (6, -1)) @@ -807,9 +781,7 @@ def forward( hidden_states = self.proj_out(hidden_states) # 8. 
Unpatchify - hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 - ) + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1) hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) diff --git a/lightx2v/models/runners/helios/helios_runner.py b/lightx2v/models/runners/helios/helios_runner.py index ad25bed02..7a383b24f 100644 --- a/lightx2v/models/runners/helios/helios_runner.py +++ b/lightx2v/models/runners/helios/helios_runner.py @@ -12,7 +12,7 @@ from lightx2v.models.schedulers.helios import HeliosDistilledScheduler from lightx2v.models.video_encoders.hf.helios import HeliosVAE from lightx2v.server.metrics import monitor_cli -from lightx2v.utils.envs import GET_DTYPE, GET_RECORDER_MODE +from lightx2v.utils.envs import GET_RECORDER_MODE from lightx2v.utils.profiler import ProfilingContext4DebugL1, ProfilingContext4DebugL2 from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v_platform.base.global_var import AI_DEVICE @@ -45,19 +45,10 @@ def _apply_image_condition_noise( video_noise_sigma_min, video_noise_sigma_max, ): - image_noise_sigma = ( - torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + image_noise_sigma_min - ) - image_latents = ( - image_noise_sigma * torch.randn(image_latents.shape, generator=generator, device=device) + (1 - image_noise_sigma) * image_latents - ) - fake_image_noise_sigma = ( - torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + video_noise_sigma_min - ) - fake_image_latents = ( - fake_image_noise_sigma * torch.randn(fake_image_latents.shape, generator=generator, device=device) - + (1 - fake_image_noise_sigma) * fake_image_latents - ) + image_noise_sigma = torch.rand(1, device=device, 
generator=generator) * (image_noise_sigma_max - image_noise_sigma_min) + image_noise_sigma_min + image_latents = image_noise_sigma * torch.randn(image_latents.shape, generator=generator, device=device) + (1 - image_noise_sigma) * image_latents + fake_image_noise_sigma = torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min) + video_noise_sigma_min + fake_image_latents = fake_image_noise_sigma * torch.randn(fake_image_latents.shape, generator=generator, device=device) + (1 - fake_image_noise_sigma) * fake_image_latents return image_latents, fake_image_latents @@ -368,7 +359,9 @@ def run_main(self): gamma = self.scheduler.inner.config.gamma alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma) beta = alpha * (1 - ori_sigma) / math.sqrt(gamma) - noise = self.sample_block_noise(batch_size, num_channels_latents, latents.shape[2], pyramid_height, pyramid_width, patch_size, device, self.scheduler.generator).to(dtype=transformer_dtype) + noise = self.sample_block_noise(batch_size, num_channels_latents, latents.shape[2], pyramid_height, pyramid_width, patch_size, device, self.scheduler.generator).to( + dtype=transformer_dtype + ) latents = alpha * latents + beta * noise start_point_list.append(latents) @@ -426,10 +419,10 @@ def run_main(self): self.gen_video = history_video self.gen_video_final = _pt_video_output_to_frames( _finalize_video_output( - history_video=self.gen_video, - video_processor=self.vae_decoder.video_processor, - temporal_scale_factor=self.vae_decoder.vae_scale_factor_temporal, - output_type="pt", + history_video=self.gen_video, + video_processor=self.vae_decoder.video_processor, + temporal_scale_factor=self.vae_decoder.vae_scale_factor_temporal, + output_type="pt", ) ) result = self.process_images_after_vae_decoder_helios() @@ -450,7 +443,11 @@ def process_images_after_vae_decoder_helios(self): if self.input_info.return_result_tensor: return {"video": self.gen_video_final} elif 
self.input_info.save_result_path is not None: - fps = self.config["video_frame_interpolation"]["target_fps"] if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps") else self.config.get("fps", 16) + fps = ( + self.config["video_frame_interpolation"]["target_fps"] + if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps") + else self.config.get("fps", 16) + ) if not dist.is_initialized() or dist.get_rank() == 0: out_path = self.input_info.save_result_path logger.info("🎬 Start to save video 🎬") diff --git a/lightx2v/models/schedulers/helios/helios_dmd.py b/lightx2v/models/schedulers/helios/helios_dmd.py index 3e3d330d5..ae5f4c763 100644 --- a/lightx2v/models/schedulers/helios/helios_dmd.py +++ b/lightx2v/models/schedulers/helios/helios_dmd.py @@ -18,7 +18,6 @@ import numpy as np import torch - from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin from diffusers.utils import BaseOutput @@ -138,9 +137,7 @@ def init_sigmas_for_each_stage(self): timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999) timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)] timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1) - self.timesteps_per_stage[i_s] = ( - timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1]) - ) + self.timesteps_per_stage[i_s] = timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1]) stage_sigmas = np.linspace(0.999, 0, training_steps + 1) self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1]) @@ -193,9 +190,7 @@ def set_timesteps( if self.config.stages == 1: if sigmas is None: - sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype( - np.float32 - ) + sigmas = 
np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(np.float32) if self.config.shift != 1.0: assert not self.config.use_dynamic_shifting sigmas = self.time_shift(self.config.shift, 1.0, sigmas) @@ -228,9 +223,7 @@ def set_timesteps( if self.config.stages == 1: self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps else: - self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * ( - self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min() - ) + self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()) # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift def time_shift(self, mu: float, sigma: float, t: torch.Tensor): diff --git a/lightx2v/models/video_encoders/hf/helios/vae.py b/lightx2v/models/video_encoders/hf/helios/vae.py index 7b761e2a8..2c4c7dfe9 100644 --- a/lightx2v/models/video_encoders/hf/helios/vae.py +++ b/lightx2v/models/video_encoders/hf/helios/vae.py @@ -6,7 +6,6 @@ from diffusers.utils import load_image from diffusers.video_processor import VideoProcessor -from lightx2v.utils.envs import GET_DTYPE from lightx2v_platform.base.global_var import AI_DEVICE @@ -59,8 +58,8 @@ def prepare_image_latents(self, image, generator, num_latent_frames_per_chunk, h def decode(self, latents): self._to_device() - latents_mean = self.latents_mean.to(device=latents.device, dtype=latents.dtype) - latents_std = self.latents_std.to(device=latents.device, dtype=latents.dtype) + latents_mean = self.latents_mean.to(device=self.model.device, dtype=self.model.dtype) + latents_std = self.latents_std.to(device=self.model.device, dtype=self.model.dtype) current_latents = latents.to(self.model.device, dtype=self.model.dtype) / latents_std + latents_mean decoded = self.model.decode(current_latents, 
return_dict=False)[0] self._to_cpu() diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index 6c559c05d..381f79190 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -161,14 +161,16 @@ def auto_calc_config(config): if key in scheduler_config: config[key] = scheduler_config[key] - is_distilled = bool(model_index.get("is_distilled", config.get("is_distilled", False))) or "Distilled" in (modular_model_index.get("_class_name") or "") or config.get("scheduler_type") == "HeliosDMDScheduler" + is_distilled = ( + bool(model_index.get("is_distilled", config.get("is_distilled", False))) + or "Distilled" in (modular_model_index.get("_class_name") or "") + or config.get("scheduler_type") == "HeliosDMDScheduler" + ) config["is_distilled"] = is_distilled if not is_distilled: scheduler_hint = config.get("scheduler_type", "unknown") raise ValueError( - f"Unsupported Helios checkpoint at {config['model_path']}: " - f"LightX2V only supports Helios-Distilled checkpoints, but detected base/unsupported metadata " - f"(scheduler={scheduler_hint})." + f"Unsupported Helios checkpoint at {config['model_path']}: LightX2V only supports Helios-Distilled checkpoints, but detected base/unsupported metadata (scheduler={scheduler_hint})." 
) config["model_cls"] = "helios_distilled" config["model_variant"] = "distilled" diff --git a/scripts/helios/run_helios_distilled_i2v.sh b/scripts/helios/run_helios_distilled_i2v.sh index 6c651c795..66d3b7146 100644 --- a/scripts/helios/run_helios_distilled_i2v.sh +++ b/scripts/helios/run_helios_distilled_i2v.sh @@ -9,7 +9,7 @@ export CUDA_VISIBLE_DEVICES=0 source ${lightx2v_path}/scripts/base/base.sh python -m lightx2v.infer \ ---model_cls helios \ +--model_cls helios_distilled \ --task i2v \ --model_path ${model_path} \ --config_json ${lightx2v_path}/configs/helios/helios_distilled_i2v.json \ diff --git a/scripts/helios/run_helios_distilled_t2v.sh b/scripts/helios/run_helios_distilled_t2v.sh index 3b94d811e..b3d121c85 100644 --- a/scripts/helios/run_helios_distilled_t2v.sh +++ b/scripts/helios/run_helios_distilled_t2v.sh @@ -8,7 +8,7 @@ export CUDA_VISIBLE_DEVICES=0 source ${lightx2v_path}/scripts/base/base.sh python -m lightx2v.infer \ ---model_cls helios \ +--model_cls helios_distilled \ --task t2v \ --model_path ${model_path} \ --config_json ${lightx2v_path}/configs/helios/helios_distilled_t2v.json \ From 981f094267b07a912f8da20ae2c241ba4db2b6ad Mon Sep 17 00:00:00 2001 From: xlycae Date: Fri, 29 May 2026 16:10:54 +0800 Subject: [PATCH 5/6] fix: sort imports in train infer entrypoint --- lightx2v_train/infer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lightx2v_train/infer.py b/lightx2v_train/infer.py index 8c533e6bd..2bc54a0b3 100644 --- a/lightx2v_train/infer.py +++ b/lightx2v_train/infer.py @@ -1,10 +1,11 @@ import argparse from lightx2v_train.data import build_data -from lightx2v_train.infer import build_inferencer from lightx2v_train.model_zoo import build_model from lightx2v_train.runtime import load_config +from lightx2v_train.infer import build_inferencer + def parse_args(): parser = argparse.ArgumentParser(description="Run inference with a trained LightX2V model.") From 83d54c4d32125613fd4d00e4caf467abb653f275 Mon Sep 17 
00:00:00 2001 From: xlycae Date: Fri, 29 May 2026 17:18:20 +0800 Subject: [PATCH 6/6] fix: tighten helios distilled PR readiness --- .../models/input_encoders/hf/helios/model.py | 3 ++- lightx2v/models/runners/helios/__init__.py | 13 +++++++++ .../models/runners/helios/helios_runner.py | 2 ++ lightx2v_train/infer.py | 3 +-- scripts/helios/run_helios_distilled_i2v.sh | 27 ++++++++++++------- scripts/helios/run_helios_distilled_t2v.sh | 23 +++++++++++----- 6 files changed, 52 insertions(+), 19 deletions(-) diff --git a/lightx2v/models/input_encoders/hf/helios/model.py b/lightx2v/models/input_encoders/hf/helios/model.py index 440fc88a6..9aa198474 100644 --- a/lightx2v/models/input_encoders/hf/helios/model.py +++ b/lightx2v/models/input_encoders/hf/helios/model.py @@ -48,7 +48,8 @@ def pack_t5_prompt_embeds(hidden_state, attention_mask, max_sequence_length, num class HeliosTextEncoder: def __init__(self, config): self.config = config - self.device = torch.device("cpu") if config.get("t5_cpu_offload", config.get("cpu_offload", False)) else torch.device(AI_DEVICE) + use_cpu = config.get("text_encoder_cpu_offload", config.get("t5_cpu_offload", config.get("cpu_offload", False))) + self.device = torch.device("cpu") if use_cpu else torch.device(AI_DEVICE) self.dtype = GET_DTYPE() self.tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"]) self.text_encoder = UMT5EncoderModel.from_pretrained(config["text_encoder_path"], torch_dtype=self.dtype).to(self.device) diff --git a/lightx2v/models/runners/helios/__init__.py b/lightx2v/models/runners/helios/__init__.py index 8b1378917..a22260298 100644 --- a/lightx2v/models/runners/helios/__init__.py +++ b/lightx2v/models/runners/helios/__init__.py @@ -1 +1,14 @@ +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from lightx2v.models.runners.helios.helios_runner import HeliosRunner + +__all__ = ["HeliosRunner"] + + +def __getattr__(name): + if name == "HeliosRunner": + from lightx2v.models.runners.helios.helios_runner 
import HeliosRunner + + return HeliosRunner + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/lightx2v/models/runners/helios/helios_runner.py b/lightx2v/models/runners/helios/helios_runner.py index 7a383b24f..4192dc9a7 100644 --- a/lightx2v/models/runners/helios/helios_runner.py +++ b/lightx2v/models/runners/helios/helios_runner.py @@ -113,6 +113,8 @@ def load_vae(self): def init_modules(self): if self.config["task"] not in ["t2v", "i2v"]: raise NotImplementedError(f"HeliosRunner only supports t2v/i2v, got {self.config['task']}") + if self.config.get("text_encoder_quantized") or self.config.get("text_encoder_quantized_ckpt") or self.config.get("text_encoder_quant_scheme"): + raise NotImplementedError("Helios native integration does not support text-encoder quantization yet.") if self.config.get("lazy_load"): raise NotImplementedError("Helios native integration does not support lazy_load.") if self.config.get("unload_modules"): diff --git a/lightx2v_train/infer.py b/lightx2v_train/infer.py index 2bc54a0b3..8c533e6bd 100644 --- a/lightx2v_train/infer.py +++ b/lightx2v_train/infer.py @@ -1,11 +1,10 @@ import argparse from lightx2v_train.data import build_data +from lightx2v_train.infer import build_inferencer from lightx2v_train.model_zoo import build_model from lightx2v_train.runtime import load_config -from lightx2v_train.infer import build_inferencer - def parse_args(): parser = argparse.ArgumentParser(description="Run inference with a trained LightX2V model.") diff --git a/scripts/helios/run_helios_distilled_i2v.sh b/scripts/helios/run_helios_distilled_i2v.sh index 66d3b7146..d47e5a0c0 100644 --- a/scripts/helios/run_helios_distilled_i2v.sh +++ b/scripts/helios/run_helios_distilled_i2v.sh @@ -1,19 +1,28 @@ #!/bin/bash -lightx2v_path= -model_path=/data1/models/BestWishYSH/Helios-Distilled -image_path= +set -euo pipefail -export CUDA_VISIBLE_DEVICES=0 +script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) 
+lightx2v_path=${LIGHTX2V_PATH:-$(cd "${script_dir}/../.." && pwd)} +model_path=${MODEL_PATH:-${1:-}} +image_path=${IMAGE_PATH:-${2:-}} -source ${lightx2v_path}/scripts/base/base.sh +if [[ -z "${model_path}" || -z "${image_path}" ]]; then + echo "Usage: MODEL_PATH=/path/to/Helios-Distilled IMAGE_PATH=/path/to/image $0" + echo " or: $0 /path/to/Helios-Distilled /path/to/image" + exit 1 +fi + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" + +source "${lightx2v_path}/scripts/base/base.sh" python -m lightx2v.infer \ --model_cls helios_distilled \ --task i2v \ ---model_path ${model_path} \ ---config_json ${lightx2v_path}/configs/helios/helios_distilled_i2v.json \ ---image_path ${image_path} \ +--model_path "${model_path}" \ +--config_json "${lightx2v_path}/configs/helios/helios_distilled_i2v.json" \ +--image_path "${image_path}" \ --prompt "The scene comes alive with subtle camera motion and realistic atmospheric movement." \ --negative_prompt "overexposed, blurry, low quality, jpeg artifacts, static frame, distorted anatomy, extra limbs" \ ---save_result_path ${lightx2v_path}/save_results/output_helios_distilled_i2v.mp4 +--save_result_path "${lightx2v_path}/save_results/output_helios_distilled_i2v.mp4" diff --git a/scripts/helios/run_helios_distilled_t2v.sh b/scripts/helios/run_helios_distilled_t2v.sh index b3d121c85..f7cde85f8 100644 --- a/scripts/helios/run_helios_distilled_t2v.sh +++ b/scripts/helios/run_helios_distilled_t2v.sh @@ -1,17 +1,26 @@ #!/bin/bash -lightx2v_path= -model_path=/data1/models/BestWishYSH/Helios-Distilled +set -euo pipefail -export CUDA_VISIBLE_DEVICES=0 +script_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +lightx2v_path=${LIGHTX2V_PATH:-$(cd "${script_dir}/../.." 
&& pwd)} +model_path=${MODEL_PATH:-${1:-}} -source ${lightx2v_path}/scripts/base/base.sh +if [[ -z "${model_path}" ]]; then + echo "Usage: MODEL_PATH=/path/to/Helios-Distilled $0" + echo " or: $0 /path/to/Helios-Distilled" + exit 1 +fi + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}" + +source "${lightx2v_path}/scripts/base/base.sh" python -m lightx2v.infer \ --model_cls helios_distilled \ --task t2v \ ---model_path ${model_path} \ ---config_json ${lightx2v_path}/configs/helios/helios_distilled_t2v.json \ +--model_path "${model_path}" \ +--config_json "${lightx2v_path}/configs/helios/helios_distilled_t2v.json" \ --prompt "A cinematic close-up of a snow leopard walking across a windy ridge at sunrise, detailed fur moving naturally in the light." \ --negative_prompt "overexposed, blurry, low quality, jpeg artifacts, static frame, distorted anatomy, extra limbs" \ ---save_result_path ${lightx2v_path}/save_results/output_helios_distilled_t2v.mp4 +--save_result_path "${lightx2v_path}/save_results/output_helios_distilled_t2v.mp4"