Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion packages/ltx-trainer/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ dependencies = [
dev = [
"pre-commit>=4.0.1",
"ruff>=0.8.6",
"pytest>=8.0",
"pytest-cov>=5.0",
]


Expand All @@ -47,8 +49,11 @@ build-backend = "hatchling.build"



[tool.pytest.ini_options]
testpaths = ["tests"]

[tool.ruff]
# Minimum Python version Ruff should assume when linting/formatting.
# Must be a Ruff version tag such as "py310" (not a dotted version string), and the key
# may be defined only once per table — TOML rejects duplicate keys.
target-version = "py310"
line-length = 120

[tool.ruff.lint]
Expand Down Expand Up @@ -87,3 +92,6 @@ ignore = [
max-args = 10
[tool.ruff.lint.isort]
known-first-party = ["ltx_trainer", "ltx_core", "ltx_pipelines"]

# Relax lint rules for test code only: the ANN (type-annotation) and T20 (print-statement)
# rule families are not enforced under tests/.
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["ANN", "T20"]
156 changes: 112 additions & 44 deletions packages/ltx-trainer/scripts/process_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ def compute_latents( # noqa: PLR0913, PLR0915
batch_size: int = 1,
device: str = "cuda",
vae_tiling: bool = False,
tile_batch_size: int = 1,
with_audio: bool = False,
audio_output_dir: str | None = None,
overwrite: bool = False,
Expand All @@ -466,6 +467,7 @@ def compute_latents( # noqa: PLR0913, PLR0915
batch_size: Batch size for processing
device: Device to use for computation
vae_tiling: Whether to enable VAE tiling
tile_batch_size: Number of same-shape tiles per VAE forward pass (only used when vae_tiling=True)
with_audio: Whether to extract and encode audio from videos
audio_output_dir: Directory to save audio latents (required if with_audio=True)
overwrite: Re-process every item even if its output exists. Use when rerunning with
Expand Down Expand Up @@ -562,7 +564,9 @@ def _is_done(idx: int) -> bool:

# Encode video
with torch.inference_mode():
video_latent_data = encode_video(vae=vae, video=video, use_tiling=vae_tiling)
video_latent_data = encode_video(
vae=vae, video=video, use_tiling=vae_tiling, tile_batch_size=tile_batch_size
)

# Save latents for each item in batch
for i in range(len(batch["relative_path"])):
Expand Down Expand Up @@ -632,6 +636,7 @@ def encode_video(
use_tiling: bool = False,
tile_size: int = DEFAULT_TILE_SIZE,
tile_overlap: int = DEFAULT_TILE_OVERLAP,
tile_batch_size: int = 1,
) -> dict[str, torch.Tensor | int]:
"""Encode video into non-patchified latent representation.
Args:
Expand All @@ -642,6 +647,9 @@ def encode_video(
use_tiling: Whether to use spatial tiling for memory efficiency
tile_size: Tile size in pixels (must be divisible by 32)
tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
tile_batch_size: Number of same-shape tiles to encode in a single VAE forward pass.
Only used when ``use_tiling=True``. See ``tiled_encode_video`` for the full
description of the two-phase mini-batch approach and the VRAM/throughput trade-off.
Returns:
Dict containing non-patchified latents and shape information:
{
Expand All @@ -662,11 +670,13 @@ def encode_video(

# Choose encoding method based on tiling flag
if use_tiling:
# Keep kwargs in sync with tiled_encode_video signature
latents = tiled_encode_video(
vae=vae,
video=video,
tile_size=tile_size,
tile_overlap=tile_overlap,
tile_batch_size=tile_batch_size,
)
else:
# Encode video - VAE expects [B, C, F, H, W], returns [B, C, F', H', W']
Expand All @@ -690,17 +700,43 @@ def tiled_encode_video( # noqa: PLR0912, PLR0915
video: torch.Tensor,
tile_size: int = DEFAULT_TILE_SIZE,
tile_overlap: int = DEFAULT_TILE_OVERLAP,
tile_batch_size: int = 1,
) -> torch.Tensor:
"""Encode video using spatial tiling for memory efficiency.

Splits the video into overlapping spatial tiles, encodes each tile separately,
and blends the results using linear feathering in the overlap regions.

Two-phase approach: Phase 1 collects tile positions grouped by pixel shape (tiles of
the same shape can share a VAE call). Phase 2 encodes each shape group in sub-lists
of ``tile_batch_size`` tiles per forward pass, then blends results into the output.

``tile_batch_size`` controls peak VRAM per VAE call: 1 (default) = lowest memory;
larger values amortise CUDA launch overhead at the cost of more tile activations
in memory simultaneously. ``tile_batch_size=1`` preserves accumulation order within
each shape group and produces output within float32 rounding tolerance of the original
sequential loop (shape-group ordering may differ from strict row-major for videos with
edge tiles).

Approximate tile counts per resolution at default tile_size=512, tile_overlap=128:
512x512 -> 1 tile (fast path: direct VAE call, no tiling)
768x768 -> 4 tiles (2x2 grid)
896x896 -> 4 tiles (2x2 grid)
1024x1024 -> 9 tiles (3x3 grid)
1280x1280 -> 16 tiles (4x4 grid)

Args:
vae: Video VAE encoder model
video: Input tensor of shape [B, C, F, H, W]
tile_size: Tile size in pixels (must be divisible by 32)
tile_overlap: Overlap between tiles in pixels (must be divisible by 32)
tile_batch_size: Number of same-shape tiles per VAE forward pass. See above for trade-offs.
Returns:
Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent]

Note:
Requires the VAE to be batch-agnostic (e.g., Conv3d-only, or BatchNorm in eval mode);
batch-statistic layers in training mode will produce different latents for tile_batch_size > 1.
"""
batch, _channels, frames, height, width = video.shape
device = video.device
Expand All @@ -713,8 +749,10 @@ def tiled_encode_video( # noqa: PLR0912, PLR0915
raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}")
if tile_overlap >= tile_size:
raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})")
if tile_batch_size < 1:
raise ValueError(f"tile_batch_size must be >= 1, got {tile_batch_size}")

# If video fits in a single tile, use regular encoding
# If video fits in a single tile, use regular encoding (tile_batch_size has no effect here)
if height <= tile_size and width <= tile_size:
return vae(video)

Expand Down Expand Up @@ -762,77 +800,102 @@ def tiled_encode_video( # noqa: PLR0912, PLR0915
overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR
overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR

# Process each tile
fade_in_h = None
fade_out_h = None
fade_in_w = None
fade_out_w = None
if overlap_out_h > 0:
fade_in_h = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
fade_out_h = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
if overlap_out_w > 0:
fade_in_w = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
fade_out_w = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]

# Phase 1: collect tile positions grouped by (tile_h, tile_w).
# Same-shape tiles can be cat-stacked for a single VAE call; cross-shape tiles cannot.
tiles_by_shape: dict[tuple[int, int], list[dict[str, int]]] = {}
for h_pos in h_positions:
for w_pos in w_positions:
# Calculate tile boundaries in input space
h_start = max(0, h_pos)
w_start = max(0, w_pos)
h_end = min(h_start + tile_size, height)
w_end = min(w_start + tile_size, width)

# Ensure tile dimensions are divisible by VAE_SPATIAL_FACTOR
tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR
tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR

if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR:
continue

# Adjust end positions
h_end = h_start + tile_h
w_end = w_start + tile_w

# Extract tile
tile = video[:, :, :, h_start:h_end, w_start:w_end]

# Encode tile
encoded_tile = vae(tile)

# Get actual encoded dimensions
_, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape

# Calculate output positions
out_h_start = h_start // VAE_SPATIAL_FACTOR
out_w_start = w_start // VAE_SPATIAL_FACTOR
tile_out_height = tile_h // VAE_SPATIAL_FACTOR
tile_out_width = tile_w // VAE_SPATIAL_FACTOR
out_h_end = min(out_h_start + tile_out_height, output_height)
out_w_end = min(out_w_start + tile_out_width, output_width)

# Trim encoded tile if necessary
actual_tile_h = out_h_end - out_h_start
actual_tile_w = out_w_end - out_w_start
encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]

# Create blending mask with linear feathering at edges
mask = torch.ones(
(1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
device=device,
dtype=dtype,
entry = {
"h_start": h_start,
"h_end": h_end,
"w_start": w_start,
"w_end": w_end,
"h_pos": h_pos,
"w_pos": w_pos,
"out_h_start": out_h_start,
"out_h_end": out_h_end,
"out_w_start": out_w_start,
"out_w_end": out_w_end,
"actual_tile_h": actual_tile_h,
"actual_tile_w": actual_tile_w,
}
tiles_by_shape.setdefault((tile_h, tile_w), []).append(entry)

# Phase 2: encode tiles in mini-batches (same shape per batch) and blend into output.
for entries in tiles_by_shape.values():
for i in range(0, len(entries), tile_batch_size):
mini_batch = entries[i : i + tile_batch_size]
tiles = torch.cat(
[video[:, :, :, e["h_start"]:e["h_end"], e["w_start"]:e["w_end"]] for e in mini_batch],
dim=0,
)
encoded_batch = vae(tiles)
chunks = encoded_batch.chunk(len(mini_batch), dim=0)

for entry, encoded_tile in zip(mini_batch, chunks, strict=True):
_, _, tile_out_frames, _, _ = encoded_tile.shape
actual_tile_h = entry["actual_tile_h"]
actual_tile_w = entry["actual_tile_w"]
encoded_tile_trimmed = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w]

# Apply feathering at edges (linear blend in overlap regions)
# Top edge (leading rows of the height dimension)
if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1)
mask = torch.ones(
(1, 1, tile_out_frames, actual_tile_h, actual_tile_w),
device=device,
dtype=dtype,
)

# Bottom edge (trailing rows of the height dimension)
if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1)
if entry["h_pos"] > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
mask[:, :, :, :overlap_out_h, :] *= fade_in_h.view(1, 1, 1, -1, 1) # type: ignore[union-attr]

# Left edge (leading columns of the width dimension)
if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1)
if entry["h_end"] < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h:
mask[:, :, :, -overlap_out_h:, :] *= fade_out_h.view(1, 1, 1, -1, 1) # type: ignore[union-attr]

# Right edge (trailing columns of the width dimension)
if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1]
mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1)
if entry["w_pos"] > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
mask[:, :, :, :, :overlap_out_w] *= fade_in_w.view(1, 1, 1, 1, -1) # type: ignore[union-attr]

# Accumulate weighted results
output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask
weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask
if entry["w_end"] < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w:
mask[:, :, :, :, -overlap_out_w:] *= fade_out_w.view(1, 1, 1, 1, -1) # type: ignore[union-attr]

out_h_start = entry["out_h_start"]
out_h_end = entry["out_h_end"]
out_w_start = entry["out_w_start"]
out_w_end = entry["out_w_end"]
output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile_trimmed * mask
weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask

# Normalize by weights (avoid division by zero)
output = output / (weights + 1e-8)
Expand Down Expand Up @@ -1022,6 +1085,10 @@ def main( # noqa: PLR0913
default=False,
help="Enable VAE tiling for larger video resolutions",
),
tile_batch_size: int = typer.Option(
default=1,
help="Number of same-shape tiles per VAE forward pass (1 = sequential). Only used when --vae-tiling is set.",
),
reshape_mode: str = typer.Option(
default="center",
help="How to crop videos: 'center' or 'random'",
Expand Down Expand Up @@ -1090,6 +1157,7 @@ def main( # noqa: PLR0913
batch_size=batch_size,
device=device,
vae_tiling=vae_tiling,
tile_batch_size=tile_batch_size,
with_audio=with_audio,
audio_output_dir=audio_output_dir,
overwrite=overwrite,
Expand Down
4 changes: 4 additions & 0 deletions packages/ltx-trainer/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import sys
from pathlib import Path

# Prepend the ``scripts`` directory that sits next to this file's parent directory to
# ``sys.path`` (index 0, so it takes precedence over other entries).
# NOTE(review): presumably this lets tests import the standalone scripts as modules — confirm.
sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "scripts"))
Loading