Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions configs/helios/helios_distilled_i2v.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"model_cls": "helios_distilled",
"model_variant": "distilled",
"infer_steps": 6,
"target_video_length": 99,
"target_height": 384,
"target_width": 640,
"sample_guide_scale": 1.0,
"enable_cfg": false,
"fps": 24,
"history_sizes": [16, 2, 1],
"num_latent_frames_per_chunk": 9,
"use_zero_init": false,
"zero_steps": 1,
"is_enable_stage2": false,
"pyramid_num_inference_steps_list": [2, 2, 2],
"is_skip_first_chunk": false,
"is_amplify_first_chunk": false,
"image_noise_sigma_min": 0.111,
"image_noise_sigma_max": 0.135
}
19 changes: 19 additions & 0 deletions configs/helios/helios_distilled_t2v.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"model_cls": "helios_distilled",
"model_variant": "distilled",
"infer_steps": 6,
"target_video_length": 99,
"target_height": 384,
"target_width": 640,
"sample_guide_scale": 1.0,
"enable_cfg": false,
"fps": 24,
"history_sizes": [16, 2, 1],
"num_latent_frames_per_chunk": 9,
"use_zero_init": false,
"zero_steps": 1,
"is_enable_stage2": false,
"pyramid_num_inference_steps_list": [2, 2, 2],
"is_skip_first_chunk": false,
"is_amplify_first_chunk": false
}
79 changes: 42 additions & 37 deletions lightx2v/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from lightx2v.common.ops import *
from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401
from lightx2v.models.runners.helios.helios_runner import HeliosRunner # noqa: F401
from lightx2v.models.runners.hunyuan3d.hunyuan3d_shape_runner import Hunyuan3DShapeRunner # noqa: F401

# from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401
Expand Down Expand Up @@ -39,6 +40,45 @@
from lightx2v.utils.utils import seed_all, validate_config_paths
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

SUPPORTED_MODEL_CLASSES = [
"wan2.1",
"wan2.1_distill",
"wan2.1_mean_flow_distill",
"wan2.1_vace",
"wan2.1_sf",
"wan2.1_sf_mtxg2",
"seko_talk",
"seko_talk_ar",
"wan2.2_moe",
"lingbot_world",
"wan2.2",
"wan2.2_matrix_game3",
"wan2.2_moe_audio",
"wan2.2_audio",
"wan2.2_moe_distill",
"wan2.2_moe_vace",
"qwen_image",
"longcat_image",
"wan2.2_animate",
"hunyuan_video_1.5",
"hunyuan_video_1.5_distill",
"helios_distilled",
"hunyuan3d",
"worldplay_distill",
"worldplay_ar",
"worldplay_bi",
"z_image",
"flux2_klein",
"flux2_dev",
"ltx2",
"bagel",
"seedvr2",
"neopp",
"motus",
"lingbot_world_fast",
"worldmirror",
]


def init_runner(config):
torch.set_grad_enabled(False)
Expand All @@ -54,43 +94,6 @@ def main():
"--model_cls",
type=str,
required=True,
choices=[
"wan2.1",
"wan2.1_distill",
"wan2.1_mean_flow_distill",
"wan2.1_vace",
"wan2.1_sf",
"wan2.1_sf_mtxg2",
"seko_talk",
"seko_talk_ar",
"wan2.2_moe",
"lingbot_world",
"wan2.2",
"wan2.2_matrix_game3",
"wan2.2_moe_audio",
"wan2.2_audio",
"wan2.2_moe_distill",
"wan2.2_moe_vace",
"qwen_image",
"longcat_image",
"wan2.2_animate",
"hunyuan_video_1.5",
"hunyuan_video_1.5_distill",
"hunyuan3d",
"worldplay_distill",
"worldplay_ar",
"worldplay_bi",
"z_image",
"flux2_klein",
"flux2_dev",
"ltx2",
"bagel",
"seedvr2",
"neopp",
"motus",
"lingbot_world_fast",
"worldmirror",
],
default="wan2.1",
)

Expand Down Expand Up @@ -221,6 +224,8 @@ def main():
parser.add_argument("--mux_audio_video_path", type=str, default=None, help="(v2av, optional) After saving, mux audio from this file into the output mp4 (ffmpeg). ")

args = parser.parse_args()
if args.model_cls not in SUPPORTED_MODEL_CLASSES:
parser.error(f"invalid --model_cls '{args.model_cls}'. Supported values: {', '.join(SUPPORTED_MODEL_CLASSES)}")
# validate_task_arguments(args)

seed_all(args.seed)
Expand Down
3 changes: 3 additions & 0 deletions lightx2v/models/input_encoders/hf/helios/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from lightx2v.models.input_encoders.hf.helios.model import HeliosTextEncoder

__all__ = ["HeliosTextEncoder"]
79 changes: 79 additions & 0 deletions lightx2v/models/input_encoders/hf/helios/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import html

import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from lightx2v.utils.envs import GET_DTYPE
from lightx2v_platform.base.global_var import AI_DEVICE

try:
import ftfy
except ImportError:
ftfy = None


def basic_clean(text):
if ftfy is not None:
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()


def whitespace_clean(text):
return re.sub(r"\s+", " ", text).strip()


def prompt_clean(text):
return whitespace_clean(basic_clean(text))


def pack_t5_prompt_embeds(hidden_state, attention_mask, max_sequence_length, num_videos_per_prompt=1, dtype=None, device=None):
device = device or hidden_state.device
dtype = dtype or hidden_state.dtype
prompt_embeds = hidden_state.to(dtype=dtype, device=device)
attention_mask = attention_mask.to(device=device)
seq_lens = attention_mask.gt(0).sum(dim=1).long()
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds],
dim=0,
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(len(seq_lens) * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, attention_mask.bool()


class HeliosTextEncoder:
def __init__(self, config):
self.config = config
use_cpu = config.get("text_encoder_cpu_offload", config.get("t5_cpu_offload", config.get("cpu_offload", False)))
self.device = torch.device("cpu") if use_cpu else torch.device(AI_DEVICE)
self.dtype = GET_DTYPE()
self.tokenizer = AutoTokenizer.from_pretrained(config["tokenizer_path"])
self.text_encoder = UMT5EncoderModel.from_pretrained(config["text_encoder_path"], torch_dtype=self.dtype).to(self.device)

def infer(self, prompts, max_sequence_length=None):
max_sequence_length = max_sequence_length or self.config.get("max_sequence_length", 512)
prompts = [prompt_clean(prompt) for prompt in prompts]
text_inputs = self.tokenizer(
prompts,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
input_ids = text_inputs.input_ids.to(self.device)
attention_mask = text_inputs.attention_mask.to(self.device)
hidden_state = self.text_encoder(input_ids, attention_mask).last_hidden_state
return pack_t5_prompt_embeds(
hidden_state,
attention_mask,
max_sequence_length=max_sequence_length,
num_videos_per_prompt=1,
dtype=self.dtype,
device=torch.device(AI_DEVICE),
)
3 changes: 3 additions & 0 deletions lightx2v/models/networks/helios/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from lightx2v.models.networks.helios.model import HeliosModel

__all__ = ["HeliosModel"]
95 changes: 95 additions & 0 deletions lightx2v/models/networks/helios/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os

import torch
from loguru import logger

from lightx2v.models.networks.helios.transformer_helios import HeliosTransformer3DModel
from lightx2v.utils.envs import GET_DTYPE


class HeliosModel:
def __init__(self, model_path, config, device):
self.config = config
self.device = device
transformer_path = config.get("transformer_model_path") or model_path
self.transformer = HeliosTransformer3DModel.from_pretrained(
transformer_path,
subfolder=None if os.path.basename(transformer_path) == "transformer" else "transformer",
torch_dtype=GET_DTYPE(),
).to(device)
self.scheduler = None
self._set_attention_backend()

def _set_attention_backend(self):
attn_type = self.config.get("attn_type")
if not attn_type:
return
try:
if attn_type == "flash_attn3":
self.transformer.set_attention_backend("_flash_3_hub")
elif attn_type == "flash_attn2":
self.transformer.set_attention_backend("flash_hub")
elif attn_type == "torch_sdpa":
self.transformer.set_attention_backend("sdpa")
except Exception as exc:
logger.warning(f"Failed to set Helios attention backend {attn_type}: {exc}")

def set_scheduler(self, scheduler):
self.scheduler = scheduler

@property
def dtype(self):
return self.transformer.dtype

def infer_noise(
self,
latents,
timestep,
encoder_hidden_states,
history_inputs,
attention_kwargs=None,
):
return self.transformer(
hidden_states=latents,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
attention_kwargs=attention_kwargs,
return_dict=False,
**history_inputs,
)[0]

def infer_cfg(
self,
latents,
timestep,
prompt_embeds,
negative_prompt_embeds,
history_inputs,
guidance_scale,
attention_kwargs=None,
is_cfg_zero_star=False,
use_zero_init=False,
zero_steps=1,
stage_idx=0,
step_idx=0,
):
with self.transformer.cache_context("cond"):
noise_pred = self.infer_noise(latents, timestep, prompt_embeds, history_inputs, attention_kwargs)

if guidance_scale <= 1.0 or negative_prompt_embeds is None:
return noise_pred

with self.transformer.cache_context("uncond"):
noise_uncond = self.infer_noise(latents, timestep, negative_prompt_embeds, history_inputs, attention_kwargs)

if is_cfg_zero_star:
positive_flat = noise_pred.view(noise_pred.shape[0], -1).float()
negative_flat = noise_uncond.view(noise_uncond.shape[0], -1).float()
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
alpha = (dot_product / squared_norm).view(noise_pred.shape[0], *([1] * (noise_pred.ndim - 1))).to(noise_pred.dtype)
if stage_idx == 0 and step_idx <= zero_steps and use_zero_init:
return noise_pred * 0.0
return noise_uncond * alpha + guidance_scale * (noise_pred - noise_uncond * alpha)

return noise_uncond + guidance_scale * (noise_pred - noise_uncond)
Loading
Loading