diff --git a/packages/ltx-trainer/src/ltx_trainer/config.py b/packages/ltx-trainer/src/ltx_trainer/config.py index 751f4fd6..220dd083 100644 --- a/packages/ltx-trainer/src/ltx_trainer/config.py +++ b/packages/ltx-trainer/src/ltx_trainer/config.py @@ -5,6 +5,7 @@ from ltx_trainer.quantization import QuantizationOptions from ltx_trainer.training_strategies.base_strategy import TrainingStrategyConfigBase +from ltx_trainer.training_strategies.keyframe_to_video import KeyframeToVideoConfig from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig @@ -89,7 +90,9 @@ def _get_strategy_discriminator(v: dict | TrainingStrategyConfigBase) -> str: # Union type for all strategy configs with discriminator TrainingStrategyConfig = Annotated[ - Annotated[TextToVideoConfig, Tag("text_to_video")] | Annotated[VideoToVideoConfig, Tag("video_to_video")], + Annotated[TextToVideoConfig, Tag("text_to_video")] + | Annotated[VideoToVideoConfig, Tag("video_to_video")] + | Annotated[KeyframeToVideoConfig, Tag("keyframe_to_video")], Discriminator(_get_strategy_discriminator), ] @@ -202,6 +205,18 @@ class ValidationConfig(ConfigBaseModel): "One image path must be provided for each validation prompt", ) + keyframe_images: list[list[str]] | None = Field( + default=None, + description="Per-prompt keyframe image paths for keyframe-to-video validation. " + "Each inner list contains the images for one validation prompt.", + ) + + keyframe_frame_indices: list[list[int]] | None = Field( + default=None, + description="Per-prompt pixel frame indices for keyframe_images. " + "Each inner list must match the corresponding keyframe_images entry.", + ) + reference_videos: list[str] | None = Field( default=None, description="List of reference video paths to use for validation. " @@ -320,6 +335,59 @@ def validate_images(cls, v: list[str] | None, info: ValidationInfo) -> list[str] return v + @field_validator("keyframe_images") + @classmethod + def validate_keyframe_images(cls, v: list[list[str]] | None, info: ValidationInfo) -> list[list[str]] | None: + """Validate that keyframe image groups match prompts and point to existing files.""" + if v is None: + return None + + num_prompts = len(info.data.get("prompts", [])) + if len(v) != num_prompts: + raise ValueError(f"Number of keyframe image groups ({len(v)}) must match number of prompts ({num_prompts})") + + for prompt_idx, keyframe_paths in enumerate(v): + if not keyframe_paths: + raise ValueError(f"Keyframe image group {prompt_idx} must contain at least one image") + for image_path in keyframe_paths: + if not Path(image_path).exists(): + raise ValueError(f"Keyframe image path '{image_path}' does not exist") + + return v + + @model_validator(mode="after") + def validate_keyframe_validation(self) -> "ValidationConfig": + """Validate keyframe image/index groups for keyframe-to-video validation.""" + if self.keyframe_images is None and self.keyframe_frame_indices is None: + return self + + if self.keyframe_images is None or self.keyframe_frame_indices is None: + raise ValueError("keyframe_images and keyframe_frame_indices must be provided together") + + if len(self.keyframe_images) != len(self.keyframe_frame_indices): + raise ValueError( + f"Number of keyframe image groups ({len(self.keyframe_images)}) must match " + f"number of keyframe index groups ({len(self.keyframe_frame_indices)})" + ) + + _width, _height, frames = self.video_dims + for prompt_idx, (keyframe_paths, frame_indices) in enumerate( + zip(self.keyframe_images, self.keyframe_frame_indices, strict=True) + ): + if len(keyframe_paths) != len(frame_indices): + raise ValueError( + f"Prompt {prompt_idx} has {len(keyframe_paths)} keyframe images but " + f"{len(frame_indices)} keyframe indices" + ) + for frame_idx in frame_indices: + if frame_idx < 0 or frame_idx >= frames: + raise ValueError( + f"Keyframe frame index {frame_idx} for prompt {prompt_idx} is outside " + f"the validation range [0, {frames - 1}]" + ) + + return self + @field_validator("reference_videos") @classmethod def validate_reference_videos(cls, v: list[str] | None, info: ValidationInfo) -> list[str] | None: diff --git a/packages/ltx-trainer/src/ltx_trainer/config_display.py b/packages/ltx-trainer/src/ltx_trainer/config_display.py index a80b1eb0..ebd6aab5 100644 --- a/packages/ltx-trainer/src/ltx_trainer/config_display.py +++ b/packages/ltx-trainer/src/ltx_trainer/config_display.py @@ -61,6 +61,10 @@ def fmt(v: object, max_len: int = 55) -> str: strategy_items.append(("Audio", fmt(cfg.training_strategy.with_audio))) if hasattr(cfg.training_strategy, "first_frame_conditioning_p"): strategy_items.append(("First Frame Cond P", str(cfg.training_strategy.first_frame_conditioning_p))) + if hasattr(cfg.training_strategy, "last_frame_conditioning_p"): + strategy_items.append(("Last Frame Cond P", str(cfg.training_strategy.last_frame_conditioning_p))) + if hasattr(cfg.training_strategy, "max_random_keyframes"): + strategy_items.append(("Max Random Keyframes", str(cfg.training_strategy.max_random_keyframes))) sections.append(("🎯 Strategy", strategy_items)) diff --git a/packages/ltx-trainer/src/ltx_trainer/trainer.py b/packages/ltx-trainer/src/ltx_trainer/trainer.py index dda3d607..b44042ec 100644 --- a/packages/ltx-trainer/src/ltx_trainer/trainer.py +++ b/packages/ltx-trainer/src/ltx_trainer/trainer.py @@ -455,7 +455,9 @@ def _load_models(self) -> None: # Check if we need VAE encoder (for image or reference video conditioning) need_vae_encoder = ( - self._config.validation.images is not None or self._config.validation.reference_videos is not None + self._config.validation.images is not None + or self._config.validation.keyframe_images is not None + or self._config.validation.reference_videos is not None ) # Load all model components (except text encoder - already handled) @@ -887,6 +889,7 @@ def _setup_accelerator(self) -> None: def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None: """Run validation by generating videos from validation prompts.""" use_images = self._config.validation.images is not None + use_keyframe_images = self._config.validation.keyframe_images is not None use_reference_videos = self._config.validation.reference_videos is not None generate_audio = self._config.validation.generate_audio inference_steps = self._config.validation.inference_steps @@ -930,6 +933,16 @@ def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None: # Convert PIL image to tensor [C, H, W] in [0, 1] condition_image = F.to_tensor(image) + # Load keyframe images if provided + keyframe_images = None + if use_keyframe_images: + keyframe_images = [] + keyframe_paths = self._config.validation.keyframe_images[prompt_idx] + keyframe_frame_indices = self._config.validation.keyframe_frame_indices[prompt_idx] + for keyframe_path, frame_idx in zip(keyframe_paths, keyframe_frame_indices, strict=True): + keyframe_image = open_image_as_srgb(keyframe_path) + keyframe_images.append((F.to_tensor(keyframe_image), frame_idx, 1.0)) + # Load reference video if provided (for IC-LoRA) reference_video = None if use_reference_videos: @@ -956,6 +969,7 @@ def _sample_videos(self, progress: TrainingProgress) -> list[Path] | None: guidance_scale=self._config.validation.guidance_scale, seed=self._config.validation.seed, condition_image=condition_image, + keyframe_images=keyframe_images, reference_video=reference_video, reference_downscale_factor=self._config.validation.reference_downscale_factor, generate_audio=generate_audio, diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py index d62408f4..64afbfef 100644 --- a/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/__init__.py @@ -2,6 +2,7 @@ This package implements the Strategy Pattern to handle different training modes: - Text-to-video training (standard generation, optionally with audio) - Video-to-video training (IC-LoRA mode with reference videos) +- Keyframe-to-video training (first/last keyframe interpolation) Each strategy encapsulates the specific logic for preparing model inputs and computing loss. """ @@ -13,16 +14,19 @@ TrainingStrategy, TrainingStrategyConfigBase, ) +from ltx_trainer.training_strategies.keyframe_to_video import KeyframeToVideoConfig, KeyframeToVideoStrategy from ltx_trainer.training_strategies.text_to_video import TextToVideoConfig, TextToVideoStrategy from ltx_trainer.training_strategies.video_to_video import VideoToVideoConfig, VideoToVideoStrategy # Type alias for all strategy config types -TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig +TrainingStrategyConfig = TextToVideoConfig | VideoToVideoConfig | KeyframeToVideoConfig __all__ = [ "DEFAULT_FPS", "VIDEO_SCALE_FACTORS", "ModelInputs", + "KeyframeToVideoConfig", + "KeyframeToVideoStrategy", "TextToVideoConfig", "TextToVideoStrategy", "TrainingStrategy", @@ -50,6 +54,8 @@ def get_training_strategy(config: TrainingStrategyConfig) -> TrainingStrategy: strategy = TextToVideoStrategy(config) case VideoToVideoConfig(): strategy = VideoToVideoStrategy(config) + case KeyframeToVideoConfig(): + strategy = KeyframeToVideoStrategy(config) case _: raise ValueError(f"Unknown training strategy config type: {type(config).__name__}") diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py index c0ad7652..3bd15ed7 100644 --- a/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/base_strategy.py @@ -34,7 +34,7 @@ class TrainingStrategyConfigBase(BaseModel): model_config = ConfigDict(extra="forbid") - name: Literal["text_to_video", "video_to_video"] = Field( + name: Literal["text_to_video", "video_to_video", "keyframe_to_video"] = Field( description="Unique name identifying the training strategy type" ) diff --git a/packages/ltx-trainer/src/ltx_trainer/training_strategies/keyframe_to_video.py b/packages/ltx-trainer/src/ltx_trainer/training_strategies/keyframe_to_video.py new file mode 100644 index 00000000..80e46f77 --- /dev/null +++ b/packages/ltx-trainer/src/ltx_trainer/training_strategies/keyframe_to_video.py @@ -0,0 +1,289 @@ +"""Keyframe-to-video training strategy. + +This strategy mirrors keyframe interpolation inference: +- Target video tokens are noised and trained normally. +- Clean first/last keyframe tokens are appended as conditioning tokens. +- Loss is computed only on target video tokens, not appended keyframes. +""" + +from typing import Any, Literal + +import torch +from pydantic import Field +from torch import Tensor + +from ltx_core.components.patchifiers import get_pixel_coords +from ltx_core.model.transformer.modality import Modality +from ltx_core.types import VideoLatentShape +from ltx_trainer import logger +from ltx_trainer.timestep_samplers import TimestepSampler +from ltx_trainer.training_strategies.base_strategy import ( + DEFAULT_FPS, + VIDEO_SCALE_FACTORS, + ModelInputs, + TrainingStrategy, + TrainingStrategyConfigBase, +) + + +class KeyframeToVideoConfig(TrainingStrategyConfigBase): + """Configuration for keyframe-to-video interpolation training.""" + + name: Literal["keyframe_to_video"] = "keyframe_to_video" + + first_frame_conditioning_p: float = Field( + default=1.0, + description="Batch-level probability of appending a clean first-frame keyframe condition", + ge=0.0, + le=1.0, + ) + + last_frame_conditioning_p: float = Field( + default=1.0, + description="Batch-level probability of appending a clean last-frame keyframe condition", + ge=0.0, + le=1.0, + ) + + random_keyframe_conditioning_p: float = Field( + default=0.0, + description="Batch-level probability of appending clean intermediate keyframe conditions", + ge=0.0, + le=1.0, + ) + + max_random_keyframes: int = Field( + default=0, + description="Maximum number of intermediate latent frames to append when random keyframe conditioning is used", + ge=0, + ) + + +class KeyframeToVideoStrategy(TrainingStrategy): + """Train generation conditioned on appended first and last keyframe tokens.""" + + config: KeyframeToVideoConfig + + def __init__(self, config: KeyframeToVideoConfig): + super().__init__(config) + + def get_data_sources(self) -> dict[str, str]: + """Keyframe training reuses standard video latents and text conditions.""" + return { + "latents": "latents", + "conditions": "conditions", + } + + def prepare_training_inputs( + self, + batch: dict[str, Any], + timestep_sampler: TimestepSampler, + ) -> ModelInputs: + """Prepare noised target tokens plus appended clean endpoint keyframe tokens.""" + latents = batch["latents"] + target_latents = latents["latents"] + + num_frames = latents["num_frames"][0].item() + height = latents["height"][0].item() + width = latents["width"][0].item() + + fps = latents.get("fps", None) + if fps is not None and not torch.all(fps == fps[0]): + logger.warning( + f"Different FPS values found in the batch. Found: {fps.tolist()}, using the first one: {fps[0].item()}" + ) + fps = fps[0].item() if fps is not None else DEFAULT_FPS + + target_tokens = self._video_patchifier.patchify(target_latents) + + conditions = batch["conditions"] + video_prompt_embeds = conditions["video_prompt_embeds"] + prompt_attention_mask = conditions["prompt_attention_mask"] + + batch_size, target_seq_len, _channels = target_tokens.shape + device = target_tokens.device + dtype = target_tokens.dtype + + sigmas = timestep_sampler.sample_for(target_tokens) + noise = torch.randn_like(target_tokens) + sigmas_expanded = sigmas.view(-1, 1, 1) + noisy_target = (1 - sigmas_expanded) * target_tokens + sigmas_expanded * noise + targets = noise - target_tokens + + target_positions = self._get_video_positions( + num_frames=num_frames, + height=height, + width=width, + batch_size=batch_size, + fps=fps, + device=device, + dtype=dtype, + ) + + latent_parts = [noisy_target] + positions_parts = [target_positions] + conditioning_masks = [torch.zeros(batch_size, target_seq_len, dtype=torch.bool, device=device)] + loss_masks = [torch.ones(batch_size, target_seq_len, dtype=torch.bool, device=device)] + + if self._sample_batch_condition(self.config.first_frame_conditioning_p, device): + self._append_keyframe( + target_latents=target_latents[:, :, :1], + frame_idx=0, + fps=fps, + latent_parts=latent_parts, + positions_parts=positions_parts, + conditioning_masks=conditioning_masks, + loss_masks=loss_masks, + ) + + if self._sample_batch_condition(self.config.last_frame_conditioning_p, device): + last_frame_idx = (num_frames - 1) * VIDEO_SCALE_FACTORS.time + self._append_keyframe( + target_latents=target_latents[:, :, -1:], + frame_idx=last_frame_idx, + fps=fps, + latent_parts=latent_parts, + positions_parts=positions_parts, + conditioning_masks=conditioning_masks, + loss_masks=loss_masks, + ) + + if self.config.max_random_keyframes > 0 and self._sample_batch_condition( + self.config.random_keyframe_conditioning_p, + device, + ): + interior_frame_indices = self._sample_interior_latent_frame_indices( + num_frames=num_frames, + max_keyframes=self.config.max_random_keyframes, + device=device, + ) + for latent_frame_idx in interior_frame_indices: + self._append_keyframe( + target_latents=target_latents[:, :, latent_frame_idx : latent_frame_idx + 1], + frame_idx=latent_frame_idx * VIDEO_SCALE_FACTORS.time, + fps=fps, + latent_parts=latent_parts, + positions_parts=positions_parts, + conditioning_masks=conditioning_masks, + loss_masks=loss_masks, + ) + + combined_latents = torch.cat(latent_parts, dim=1) + conditioning_mask = torch.cat(conditioning_masks, dim=1) + video_loss_mask = torch.cat(loss_masks, dim=1) + positions = torch.cat(positions_parts, dim=2) + timesteps = self._create_per_token_timesteps(conditioning_mask, sigmas.squeeze()) + + video_modality = Modality( + enabled=True, + sigma=sigmas, + latent=combined_latents, + timesteps=timesteps, + positions=positions, + context=video_prompt_embeds, + context_mask=prompt_attention_mask, + ) + + return ModelInputs( + video=video_modality, + audio=None, + video_targets=targets, + audio_targets=None, + video_loss_mask=video_loss_mask, + audio_loss_mask=None, + ) + + def compute_loss( + self, + video_pred: Tensor, + _audio_pred: Tensor | None, + inputs: ModelInputs, + ) -> Tensor: + """Compute masked MSE on target tokens only. Returns [B,].""" + target_seq_len = inputs.video_targets.shape[1] + target_pred = video_pred[:, :target_seq_len, :] + target_loss_mask = inputs.video_loss_mask[:, :target_seq_len] + + loss = (target_pred - inputs.video_targets).pow(2) + loss_mask = target_loss_mask.unsqueeze(-1).float() + masked = loss.mul(loss_mask) + return masked.mean(dim=[-2, -1]) / loss_mask.mean(dim=[-2, -1]).clamp(min=1e-8) + + @staticmethod + def _sample_batch_condition(probability: float, device: torch.device) -> bool: + if probability <= 0.0: + return False + if probability >= 1.0: + return True + return bool((torch.rand((), device=device) < probability).item()) + + @staticmethod + def _sample_interior_latent_frame_indices( + num_frames: int, + max_keyframes: int, + device: torch.device, + ) -> list[int]: + interior_count = max(0, num_frames - 2) + if interior_count == 0 or max_keyframes == 0: + return [] + + keyframe_count = min(interior_count, max_keyframes) + indices = torch.randperm(interior_count, device=device)[:keyframe_count] + 1 + return sorted(int(idx.item()) for idx in indices) + + def _append_keyframe( + self, + target_latents: Tensor, + frame_idx: int, + fps: float, + latent_parts: list[Tensor], + positions_parts: list[Tensor], + conditioning_masks: list[Tensor], + loss_masks: list[Tensor], + ) -> None: + keyframe_tokens = self._video_patchifier.patchify(target_latents) + keyframe_positions = self._get_keyframe_positions( + keyframes=target_latents, + frame_idx=frame_idx, + fps=fps, + ) + + batch_size, keyframe_seq_len, _channels = keyframe_tokens.shape + keyframe_conditioning_mask = torch.ones( + batch_size, + keyframe_seq_len, + dtype=torch.bool, + device=keyframe_tokens.device, + ) + keyframe_loss_mask = torch.zeros( + batch_size, + keyframe_seq_len, + dtype=torch.bool, + device=keyframe_tokens.device, + ) + + latent_parts.append(keyframe_tokens) + positions_parts.append(keyframe_positions) + conditioning_masks.append(keyframe_conditioning_mask) + loss_masks.append(keyframe_loss_mask) + + def _get_keyframe_positions( + self, + keyframes: Tensor, + frame_idx: int, + fps: float, + ) -> Tensor: + latent_coords = self._video_patchifier.get_patch_grid_bounds( + output_shape=VideoLatentShape.from_torch_shape(keyframes.shape), + device=keyframes.device, + ) + positions = get_pixel_coords( + latent_coords=latent_coords, + scale_factors=VIDEO_SCALE_FACTORS, + causal_fix=frame_idx == 0, + ) + positions[:, 0, ...] += frame_idx + positions[:, 0, ..., 1:] = positions[:, 0, ..., :1] + 1 + positions = positions.to(dtype=torch.float32) + positions[:, 0, ...] /= fps + return positions.to(dtype=keyframes.dtype) diff --git a/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py b/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py index b756cf14..4f213401 100644 --- a/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py +++ b/packages/ltx-trainer/src/ltx_trainer/validation_sampler.py @@ -19,6 +19,7 @@ get_pixel_coords, ) from ltx_core.components.schedulers import LTX2Scheduler +from ltx_core.conditioning import VideoConditionByKeyframeIndex from ltx_core.guidance.perturbations import ( BatchedPerturbationConfig, Perturbation, @@ -85,6 +86,7 @@ class GenerationConfig: guidance_scale: float = 4.0 # CFG guidance scale seed: int = 42 # Random seed for reproducibility condition_image: Tensor | None = None # Optional first frame image for image-to-video + keyframe_images: list[tuple[Tensor, int, float]] | None = None # Optional keyframe images for interpolation reference_video: Tensor | None = None # For IC-LoRA: [F, C, H, W] in [0, 1] reference_downscale_factor: int = 1 # For IC-LoRA: downscale factor (1 = same resolution, 2 = half resolution) generate_audio: bool = True # Whether to generate audio alongside video @@ -203,6 +205,15 @@ def _generate_standard(self, config: GenerationConfig, device: torch.device) -> video_clean_state, config.condition_image, config, device ) + if config.keyframe_images is not None: + video_clean_state = self._apply_keyframe_conditioning( + video_clean_state, + config.keyframe_images, + config, + video_tools, + device, + ) + # Add noise noiser = GaussianNoiser(generator=generator) video_state = noiser(latent_state=video_clean_state, noise_scale=1.0) @@ -385,6 +396,27 @@ def _apply_image_conditioning( clean_latent=new_clean_latent, ) + def _apply_keyframe_conditioning( + self, + video_state: LatentState, + keyframe_images: list[tuple[Tensor, int, float]], + config: GenerationConfig, + video_tools: VideoLatentTools, + device: torch.device, + ) -> LatentState: + """Append keyframe image latents using the same guiding-token path as keyframe interpolation.""" + for image, frame_idx, strength in keyframe_images: + encoded_image = self._encode_conditioning_image(image, config.height, config.width, device) + keyframe_cond = VideoConditionByKeyframeIndex( + keyframes=encoded_image, + frame_idx=frame_idx, + strength=strength, + num_pixel_frames=1, + ) + video_state = keyframe_cond.apply_to(video_state, video_tools) + + return video_state + @staticmethod def _preprocess_reference_video(config: GenerationConfig) -> Tensor: """Preprocess reference video: resize, crop, and convert to model input format. @@ -675,6 +707,8 @@ def _validate_config(self, config: GenerationConfig) -> None: raise ValueError("Audio generation requires audio_decoder and vocoder") if config.condition_image is not None and self._vae_encoder is None: raise ValueError("Image conditioning requires vae_encoder") + if config.keyframe_images is not None and self._vae_encoder is None: + raise ValueError("Keyframe image conditioning requires vae_encoder") if config.reference_video is not None and self._vae_encoder is None: raise ValueError("Reference video conditioning requires vae_encoder") diff --git a/packages/ltx-trainer/tests/test_keyframe_to_video_strategy.py b/packages/ltx-trainer/tests/test_keyframe_to_video_strategy.py new file mode 100644 index 00000000..2c485b06 --- /dev/null +++ b/packages/ltx-trainer/tests/test_keyframe_to_video_strategy.py @@ -0,0 +1,103 @@ +import unittest + +import torch + +from ltx_trainer.training_strategies.keyframe_to_video import ( + KeyframeToVideoConfig, + KeyframeToVideoStrategy, +) + + +class FixedTimestepSampler: + def __init__(self, sigma: float = 0.5) -> None: + self.sigma = sigma + + def sample_for(self, latents: torch.Tensor) -> torch.Tensor: + return torch.full((latents.shape[0], 1, 1), self.sigma, device=latents.device, dtype=latents.dtype) + + +class KeyframeToVideoStrategyTest(unittest.TestCase): + def test_appends_clean_first_and_last_keyframe_tokens(self) -> None: + torch.manual_seed(0) + strategy = KeyframeToVideoStrategy(KeyframeToVideoConfig()) + batch = self._batch(batch_size=2, channels=4, frames=3, height=2, width=3) + + inputs = strategy.prepare_training_inputs(batch, FixedTimestepSampler(sigma=0.25)) + + target_tokens = 3 * 2 * 3 + keyframe_tokens = 2 * 3 + first_keyframe_start = target_tokens + last_keyframe_start = target_tokens + keyframe_tokens + + expected_first = strategy._video_patchifier.patchify(batch["latents"]["latents"][:, :, :1]) + expected_last = strategy._video_patchifier.patchify(batch["latents"]["latents"][:, :, -1:]) + + self.assertEqual(inputs.video.latent.shape, (2, target_tokens + 2 * keyframe_tokens, 4)) + torch.testing.assert_close( + inputs.video.latent[:, first_keyframe_start:last_keyframe_start], + expected_first, + ) + torch.testing.assert_close( + inputs.video.latent[:, last_keyframe_start:], + expected_last, + ) + + self.assertTrue(torch.all(inputs.video.timesteps[:, :target_tokens] == 0.25)) + self.assertTrue(torch.all(inputs.video.timesteps[:, target_tokens:] == 0.0)) + + self.assertTrue(torch.all(inputs.video_loss_mask[:, :target_tokens])) + self.assertFalse(torch.any(inputs.video_loss_mask[:, target_tokens:])) + + # A 3-frame latent sequence corresponds to pixel keyframes at frame 0 and frame 16. + last_keyframe_positions = inputs.video.positions[:, :, last_keyframe_start:] + torch.testing.assert_close( + last_keyframe_positions[:, 0, :, 0], + torch.full((2, keyframe_tokens), 16 / 24, dtype=last_keyframe_positions.dtype), + ) + torch.testing.assert_close( + last_keyframe_positions[:, 0, :, 1], + torch.full((2, keyframe_tokens), 17 / 24, dtype=last_keyframe_positions.dtype), + ) + + def test_loss_ignores_appended_keyframe_tokens(self) -> None: + torch.manual_seed(0) + strategy = KeyframeToVideoStrategy(KeyframeToVideoConfig()) + batch = self._batch(batch_size=2, channels=4, frames=3, height=2, width=3) + + inputs = strategy.prepare_training_inputs(batch, FixedTimestepSampler(sigma=0.25)) + + target_tokens = inputs.video_targets.shape[1] + video_pred = torch.zeros_like(inputs.video.latent) + video_pred[:, target_tokens:] = 1_000_000.0 + + expected = inputs.video_targets.pow(2).mean(dim=[-2, -1]) + torch.testing.assert_close(strategy.compute_loss(video_pred, None, inputs), expected) + + @staticmethod + def _batch( + batch_size: int, + channels: int, + frames: int, + height: int, + width: int, + ) -> dict[str, dict[str, torch.Tensor]]: + values = torch.arange(batch_size * channels * frames * height * width, dtype=torch.float32) + latents = values.reshape(batch_size, channels, frames, height, width) + return { + "latents": { + "latents": latents, + "num_frames": torch.full((batch_size,), frames), + "height": torch.full((batch_size,), height), + "width": torch.full((batch_size,), width), + "fps": torch.full((batch_size,), 24.0), + }, + "conditions": { + "video_prompt_embeds": torch.randn(batch_size, 5, 8), + "audio_prompt_embeds": torch.randn(batch_size, 5, 8), + "prompt_attention_mask": torch.ones(batch_size, 5, dtype=torch.bool), + }, + } + + +if __name__ == "__main__": + unittest.main()