Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion packages/ltx-trainer/src/ltx_trainer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ltx_trainer.quantization import QuantizationOptions
from ltx_trainer.training_strategies.base_strategy import TrainingStrategyConfigBase
from ltx_trainer.training_strategies.keyframe_to_video import KeyframeToVideoConfig
from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig
from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig

Expand Down Expand Up @@ -89,7 +90,9 @@ def _get_strategy_discriminator(v: dict | TrainingStrategyConfigBase) -> str:

# Union type for all strategy configs with discriminator
TrainingStrategyConfig = Annotated[
Annotated[TextToVideoConfig, Tag("text_to_video")] | Annotated[VideoToVideoConfig, Tag("video_to_video")],
Annotated[TextToVideoConfig, Tag("text_to_video")]
| Annotated[VideoToVideoConfig, Tag("video_to_video")]
| Annotated[KeyframeToVideoConfig, Tag("keyframe_to_video")],
Discriminator(_get_strategy_discriminator),
]

Expand Down Expand Up @@ -202,6 +205,18 @@ class ValidationConfig(ConfigBaseModel):
"One image path must be provided for each validation prompt",
)

keyframe_images: list[list[str]] | None = Field(
default=None,
description="Per-prompt keyframe image paths for keyframe-to-video validation. "
"Each inner list contains the images for one validation prompt.",
)

keyframe_frame_indices: list[list[int]] | None = Field(
default=None,
description="Per-prompt pixel frame indices for keyframe_images. "
"Each inner list must match the corresponding keyframe_images entry.",
)

reference_videos: list[str] | None = Field(
default=None,
description="List of reference video paths to use for validation. "
Expand Down Expand Up @@ -320,6 +335,59 @@ def validate_images(cls, v: list[str] | None, info: ValidationInfo) -> list[str]

return v

@field_validator("keyframe_images")
@classmethod
def validate_keyframe_images(cls, v: list[list[str]] | None, info: ValidationInfo) -> list[list[str]] | None:
    """Check the per-prompt keyframe image groups supplied for validation.

    Args:
        v: One list of image paths per validation prompt, or ``None`` when
            keyframe conditioning is not configured.
        info: Validation context; ``info.data`` is consulted for ``prompts``.

    Returns:
        ``v`` unchanged (``None`` stays ``None``).

    Raises:
        ValueError: If the number of groups differs from the number of prompts,
            if any group is empty, or if any listed path does not exist on disk.
    """
    if v is None:
        return None

    # A missing "prompts" entry in info.data is treated as zero prompts.
    num_prompts = len(info.data.get("prompts", []))
    if num_prompts != len(v):
        raise ValueError(f"Number of keyframe image groups ({len(v)}) must match number of prompts ({num_prompts})")

    for prompt_idx, group in enumerate(v):
        if not group:
            raise ValueError(f"Keyframe image group {prompt_idx} must contain at least one image")

        # Report the first path in this group that is absent from the filesystem.
        image_path = next((candidate for candidate in group if not Path(candidate).exists()), None)
        if image_path is not None:
            raise ValueError(f"Keyframe image path '{image_path}' does not exist")

    return v

@model_validator(mode="after")
def validate_keyframe_validation(self) -> "ValidationConfig":
    """Cross-check ``keyframe_images`` against ``keyframe_frame_indices``.

    Both fields must be set together, contain the same number of per-prompt
    groups, pair every image with exactly one frame index, and every index
    must fall inside ``[0, frames - 1]`` where ``frames`` is the third element
    of ``self.video_dims``.

    Returns:
        The model instance itself, unmodified.

    Raises:
        ValueError: On any of the inconsistencies described above.
    """
    image_groups = self.keyframe_images
    index_groups = self.keyframe_frame_indices

    # Keyframe conditioning is optional: nothing to check when neither is set.
    if image_groups is None and index_groups is None:
        return self

    # Exactly one of the two being set is a configuration error.
    if image_groups is None or index_groups is None:
        raise ValueError("keyframe_images and keyframe_frame_indices must be provided together")

    if len(image_groups) != len(index_groups):
        raise ValueError(
            f"Number of keyframe image groups ({len(self.keyframe_images)}) must match "
            f"number of keyframe index groups ({len(self.keyframe_frame_indices)})"
        )

    # Only the frame count is needed from the validation video dimensions.
    _width, _height, frames = self.video_dims

    paired_groups = zip(image_groups, index_groups, strict=True)
    for prompt_idx, (keyframe_paths, frame_indices) in enumerate(paired_groups):
        if len(keyframe_paths) != len(frame_indices):
            raise ValueError(
                f"Prompt {prompt_idx} has {len(keyframe_paths)} keyframe images but "
                f"{len(frame_indices)} keyframe indices"
            )

        for frame_idx in frame_indices:
            if not 0 <= frame_idx < frames:
                raise ValueError(
                    f"Keyframe frame index {frame_idx} for prompt {prompt_idx} is outside "
                    f"the validation range [0, {frames - 1}]"
                )

    return self

@field_validator("reference_videos")
@classmethod
def validate_reference_videos(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None:
Expand Down
4 changes: 4 additions & 0 deletions packages/ltx-trainer/src/ltx_trainer/config_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def fmt(v: object, max_len: int = 55) -> str:
strategy_items.append(("Audio", fmt(cfg.training_strategy.with_audio)))
if hasattr(cfg.training_strategy, "first_frame_conditioning_p"):
strategy_items.append(("First Frame Cond P", str(cfg.training_strategy.first_frame_conditioning_p)))
if hasattr(cfg.training_strategy, "last_frame_conditioning_p"):
strategy_items.append(("Last Frame Cond P", str(cfg.training_strategy.last_frame_conditioning_p)))
if hasattr(cfg.training_strategy, "max_random_keyframes"):
strategy_items.append(("Max Random Keyframes", str(cfg.training_strategy.max_random_keyframes)))

sections.append(("🎯 Strategy", strategy_items))

Expand Down
16 changes: 15 additions & 1 deletion packages/ltx-trainer/src/ltx_trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,9 @@ def _load_models(self) -> None:

# Check if we need VAE encoder (for image or reference video conditioning)
need_vae_encoder = (
self._config.validation.images is not None or self._config.validation.reference_videos is not None
self._config.validation.images is not None
or self._config.validation.keyframe_images is not None
or self._config.validation.reference_videos is not None
)

# Load all model components (except text encoder - already handled)
Expand Down Expand Up @@ -887,6 +889,7 @@ def _setup_accelerator(self) -> None:
def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None:
"""Run validation by generating videos from validation prompts."""
use_images = self._config.validation.images is not None
use_keyframe_images = self._config.validation.keyframe_images is not None
use_reference_videos = self._config.validation.reference_videos is not None
generate_audio = self._config.validation.generate_audio
inference_steps = self._config.validation.inference_steps
Expand Down Expand Up @@ -930,6 +933,16 @@ def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None:
# Convert PIL image to tensor [C, H, W] in [0, 1]
condition_image = F.to_tensor(image)

# Load keyframe images if provided
keyframe_images = None
if use_keyframe_images:
keyframe_images = []
keyframe_paths = self._config.validation.keyframe_images[prompt_idx]
keyframe_frame_indices = self._config.validation.keyframe_frame_indices[prompt_idx]
for keyframe_path, frame_idx in zip(keyframe_paths, keyframe_frame_indices, strict=True):
keyframe_image = open_image_as_srgb(keyframe_path)
keyframe_images.append((F.to_tensor(keyframe_image), frame_idx, 1.0))

# Load reference video if provided (for IC-LoRA)
reference_video = None
if use_reference_videos:
Expand All @@ -956,6 +969,7 @@ def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None:
guidance_scale=self._config.validation.guidance_scale,
seed=self._config.validation.seed,
condition_image=condition_image,
keyframe_images=keyframe_images,
reference_video=reference_video,
reference_downscale_factor=self._config.validation.reference_downscale_factor,
generate_audio=generate_audio,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This package implements the Strategy Pattern to handle different training modes:
- Text-to-video training (standard generation, optionally with audio)
- Video-to-video training (IC-LoRA mode with reference videos)
- Keyframe-to-video training (first/last keyframe interpolation)
Each strategy encapsulates the specific logic for preparing model inputs and computing loss.
"""

Expand All @@ -13,16 +14,19 @@
TrainingStrategy,
TrainingStrategyConfigBase,
)
from ltx_trainer.training_strategies.keyframe_to_video import KeyframeToVideoConfig, KeyframeToVideoStrategy
from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig, TextToVideoStrategy
from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig, VideoToVideoStrategy

# Type alias for all strategy config types
TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig
TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig | KeyframeToVideoConfig

__all__ = [
"DEFAULT_FPS",
"VIDEO_SCALE_FACTORS",
"ModelInputs",
"KeyframeToVideoConfig",
"KeyframeToVideoStrategy",
"TextToVideoConfig",
"TextToVideoStrategy",
"TrainingStrategy",
Expand Down Expand Up @@ -50,6 +54,8 @@ def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy:
strategy = TextToVideoStrategy(config)
case VideoToVideoConfig():
strategy = VideoToVideoStrategy(config)
case KeyframeToVideoConfig():
strategy = KeyframeToVideoStrategy(config)
case _:
raise ValueError(f"Unknown training strategy config type: {type(config).__name__}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TrainingStrategyConfigBase(BaseModel):

model_config = ConfigDict(extra="forbid")

name: Literal["text_to_video", "video_to_video"] = Field(
name: Literal["text_to_video", "video_to_video", "keyframe_to_video"] = Field(
description="Unique name identifying the training strategy type"
)

Expand Down
Loading