Skip to content

perf(trainer): add tile_batch_size, cut VAE calls up to 15x in tiled encoding#215

Open
Vittoria Lanzo (VittoriaLanzo) wants to merge 3 commits into
Lightricks:mainfrom
VittoriaLanzo:perf/tile-batch-encode
Open

perf(trainer): add tile_batch_size, cut VAE calls up to 15x in tiled encoding#215
Vittoria Lanzo (VittoriaLanzo) wants to merge 3 commits into
Lightricks:mainfrom
VittoriaLanzo:perf/tile-batch-encode

Conversation

@VittoriaLanzo
Copy link
Copy Markdown

@VittoriaLanzo Vittoria Lanzo (VittoriaLanzo) commented May 14, 2026

The problem

When --vae-tiling is enabled, tiled_encode_video calls the VAE encoder once per spatial tile in a sequential loop — 15 calls at 1920×1080, 60 at 3840×2160. Each call is a full encoder round-trip. The total cost grows linearly with tile count regardless of available GPU parallelism.

A secondary inefficiency: four feathering torch.linspace vectors (fade_in_h, fade_out_h, fade_in_w, fade_out_w) are reallocated inside the tile loop on every iteration despite depending only on tile_overlap and VAE_SPATIAL_FACTOR, which are constant for the duration of the call. At 3840×2160 this is up to 240 redundant small-tensor allocations per tiled_encode_video call.

Changes

Mini-batch VAE forward passes

A new parameter tile_batch_size: int = 1 is added to tiled_encode_video, encode_video, compute_latents, and the CLI (--tile-batch-size).

Encoding uses a two-phase approach:

Phase 1 (collect) — iterate tile positions in the original row-major order and record boundary metadata. Tiles are grouped by pixel shape (tile_h, tile_w). Shape grouping is required because torch.cat requires identical non-batch dimensions: edge tiles are floor-aligned to VAE_SPATIAL_FACTOR=32 and may have different pixel dimensions from interior tiles.

Phase 2 (encode + blend) — for each shape group, split its tile list into sub-lists of at most tile_batch_size. Stack crops with torch.cat(dim=0), call vae once, split with .chunk(len(sub_list), dim=0) — not .chunk(tile_batch_size) so the final partial sub-list is never over-split — then run the existing blending accumulation unchanged per chunk.

tile_batch_size=1 (default) is API-backward-compatible and numerically equivalent to the pre-patch code within float32 rounding tolerance.

Loop-invariant linspace hoisting

The four linspace allocations move to immediately after overlap_out_h/overlap_out_w are computed, before the tile loop. All guard conditions and .view() calls remain in-loop. Eliminates up to 4 × (tile count) redundant small-tensor allocations per call.

VAE call count reduction

Verified by running a call-counting wrapper on the patched code at default settings (tile_size=512, tile_overlap=128). "Minimum calls" requires tile_batch_size >= largest group size and equals the number of distinct shape groups.

Resolution Current calls tile_batch_size for min Min calls Call reduction
896x896 4 4 1 4x
1280x720 6 3 2 3x
1920x1080 15 8 4 3.75x
2560x1440 28 18 4 7x
3840x2160 60 45 4 15x

At every resolution there are at most four shape groups (at most 2 distinct heights x 2 distinct widths after edge rounding). No GPU benchmark was run; wall-time improvement depends on hardware and the per-call overhead of the real LTX-2 encoder.

Shape group breakdown
896x896:   (512x512)x4
1280x720:  (512x512)x3  (320x512)x3
1920x1080: (512x512)x8  (288x512)x4  (512x384)x2  (288x384)x1
2560x1440: (512x512)x18 (288x512)x6  (512x256)x3  (288x256)x1
3840x2160: (512x512)x45 (224x512)x9  (512x384)x5  (224x384)x1

Dimensions are (tile_height x tile_width).

Benchmark script

Run this on your hardware to measure actual wall-time speedup with the real LTX-2 encoder:

"""Benchmark tiled_encode_video: call count and wall time.

Usage (from packages/ltx-trainer):
    uv run python ../../benchmark_tile_batch.py \
        --checkpoint /path/to/ltx2.safetensors \
        --resolution 1920x1080 \
        --tile-batch-size 8
"""
import argparse
import time
import torch
from ltx_trainer.model_loader import load_video_vae_encoder
from scripts.process_videos import tiled_encode_video


def count_calls(vae, video, batch_size):
    calls = [0]
    orig_forward = vae.__class__.forward

    def patched(self, x):
        calls[0] += 1
        return orig_forward(self, x)

    vae.__class__.forward = patched
    try:
        tiled_encode_video(vae, video, tile_batch_size=batch_size)
    finally:
        vae.__class__.forward = orig_forward
    return calls[0]


def bench(vae, video, batch_size, warmup=2, reps=5):
    for _ in range(warmup):
        tiled_encode_video(vae, video, tile_batch_size=batch_size)
    if video.is_cuda:
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(reps):
        tiled_encode_video(vae, video, tile_batch_size=batch_size)
    if video.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / reps * 1000  # ms


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--resolution", default="1920x1080")
    parser.add_argument("--frames", type=int, default=9)
    parser.add_argument("--tile-batch-size", type=int, default=8)
    parser.add_argument("--device", default="cuda")
    args = parser.parse_args()

    W, H = map(int, args.resolution.split("x"))
    device = torch.device(args.device)

    print(f"Loading VAE from {args.checkpoint} ...")
    vae = load_video_vae_encoder(args.checkpoint, device=device, dtype=torch.bfloat16)
    vae.eval()

    video = torch.rand(1, 3, args.frames, H, W, device=device, dtype=torch.bfloat16)

    n_seq = count_calls(vae, video, batch_size=1)
    n_bat = count_calls(vae, video, batch_size=args.tile_batch_size)
    ms_seq = bench(vae, video, batch_size=1)
    ms_bat = bench(vae, video, batch_size=args.tile_batch_size)
    speedup = ms_seq / ms_bat if ms_bat > 0 else float("inf")

    print(f"\nResolution      {args.resolution}  frames={args.frames}")
    print(f"tile_batch_size {args.tile_batch_size}")
    print(f"VAE calls       {n_seq} -> {n_bat}  ({n_seq / n_bat:.1f}x fewer)")
    print(f"Wall time       {ms_seq:.1f} ms -> {ms_bat:.1f} ms  ({speedup:.2f}x speedup)")

Design choices

Two-phase design. Separating metadata collection (Phase 1, pure Python) from tensor operations (Phase 2) keeps each phase auditable in isolation and avoids deferring partial blending state mid-loop.

.chunk(len(sub_list)) not .chunk(tile_batch_size). The final sub-list in a group may be smaller than tile_batch_size. Using .chunk(tile_batch_size) would produce extra empty tensors that cause silent index errors in the blending loop.

Validation before fast-path. tile_batch_size < 1 is rejected before the single-tile early return so the error fires regardless of video size, consistent with how tile_size and tile_overlap are validated unconditionally.

pytest in dev group. uv run pytest activates only the dev dependency group; a custom group would require --group <name> on every invocation.

Correctness

  • Cross-shape tiles are never stacked; the shape-group dict enforces this before every torch.cat.
  • Blending accumulation is unchanged.
  • For tile_batch_size > 1 to produce the same result as tile_batch_size=1, the VAE must produce the same output for a tile regardless of batch co-members. The parity tests confirm this for the test mock; verify on the real encoder before increasing tile_batch_size.

Usage

# Default - identical to original behaviour
encode_video(vae, video, use_tiling=True)

# Good starting point on a 24 GB GPU at 1920x1080
encode_video(vae, video, use_tiling=True, tile_batch_size=8)
python scripts/process_videos.py dataset.csv \
    --resolution-buckets 1920x1080x25 \
    --output-dir ./latents \
    --model-path /path/to/ltx2.safetensors \
    --vae-tiling \
    --tile-batch-size 8

tile_batch_size=N increases peak VRAM per forward pass by approximately Nx the single-tile activation cost. tile_batch_size=1 (default) is always safe. At 3840×2160 the interior shape group has 45 tiles — tile_batch_size=45 would batch all of them at once (peak activation ~45× per-tile cost); start at a low value and increase based on available VRAM.

Tests

New: packages/ltx-trainer/tests/test_tiled_encode.py - 12 tests, all fail on pre-patch code.

Test Covers
test_output_shape Correct latent shape
test_tile_batch_size_produces_identical_output Parity: batched == sequential (896x896, one shape group)
test_tile_batch_size_identical_output_mixed_shapes Parity on 960x960 (four shape groups)
test_tile_batch_size_larger_than_tile_count tile_batch_size=100 on 4 tiles - partial last sub-list
test_tile_batch_size_zero_raises 0 -> ValueError before fast-path
test_tile_batch_size_negative_raises -1 -> ValueError (guard is < 1, not == 0)
test_encode_video_threads_tile_batch_size Parameter live through encode_video
test_fast_path_single_tile Pre-existing fast-path not broken
test_mixed_shape_group_call_count Exact call counts: 4 batched vs 9 sequential on 960x960
test_tile_size_not_divisible_raises Pre-existing guard
test_tile_overlap_not_divisible_raises Pre-existing guard
test_tile_overlap_gte_tile_size_raises Pre-existing guard

test_tile_batch_size_zero_raises and test_tile_batch_size_negative_raises are separate because changing the guard from < 1 to == 0 would pass the zero test but silently produce all-NaN output for negative values (empty range step). MockVAE uses Conv3d(stride=(8,32,32)) so its temporal output matches 1 + (F-1) // VAE_TEMPORAL_FACTOR, preventing the shape mismatch a stride-1 mock would cause.

pyproject.toml: adds pytest>=8.0 / pytest-cov>=5.0 to [dependency-groups] dev; corrects ruff target-version from "1.1.3" (the project version string) to "py310"; adds [tool.pytest.ini_options]. These changes are bundled here because they are prerequisites for the test suite this PR adds: without the dev deps uv run pytest fails with "No module named pytest", and without the corrected target-version ruff cannot parse pyproject.toml and ruff check always errors. Happy to split into a separate preparatory PR if preferred.

Checklist

  • Tests added (12 tests, all fail on pre-patch code)
  • No new runtime dependencies
  • Default tile_batch_size=1 is backward-compatible
  • ruff check passes
  • uv run pytest passes

Agentic workflow

This contribution was produced with a multi-agent pipeline under my direction and supervision. I approved each stage at human checkpoints; the pipeline handled planning, implementation, testing, and adversarial review.

Stage 0 — Performance assessment (identifies the tiled-encode hotspot):

flowchart TD
    S([STARTUP\nroster check · pre-flight]):::green --> O[ORCHESTRATOR\nmain thread · routing only]:::blue
    O --> R[RECON AGENT\nstatic analysis · 5 passes · hotspot index]:::blue
    R -->|HOTSPOT_INDEX| D[DISPATCHER\nmechanical signal routing]:::yellow
    D -->|ROUTING_MANIFEST| C[COMPLEXITY AGENT\nnested loops · O n2]:::red
    D -->|ROUTING_MANIFEST| M[MEMORY AGENT\nalloc in loops · GC pressure]:::red
    D -->|ROUTING_MANIFEST| IO[IO AGENT\nN+1 · blocking async · locks]:::red
    D -->|ROUTING_MANIFEST| DS[DATASTRUCTURES AGENT\nwrong DS · linear search]:::red
    D -->|ROUTING_MANIFEST| CA[CACHE AGENT\nredundant compute · regex in loop]:::red
    C -->|FINDINGS| SC[SCORING COLLECTOR\nmerge · deduplicate · priority_score · top 15]:::yellow
    M -->|FINDINGS| SC
    IO -->|FINDINGS| SC
    DS -->|FINDINGS| SC
    CA -->|FINDINGS| SC
    SC -->|SCORED_FINDINGS| PQ[PRE-QUALIFICATION GATE\ngit log · gh pr · grep\neliminate owned or intentional findings]:::purple
    PQ -->|cleared_findings| TR[TOT ROOT\nBranch A: correctness bugs\nBranch B: high-value speed\nBranch C: risk / low-priority]:::purple
    TR -->|branches A · B · C| AS[ASSESSMENT SYNTHESIZER\nread code · confirm or discard each finding]:::purple
    AS -->|ASSESSMENT_REPORT| HR([HUMAN REVIEW GATE\npipeline halts · human decides]):::green

    classDef green fill:#1a4a1a,stroke:#4CAF50,color:#4CAF50
    classDef blue fill:#1a2a4a,stroke:#5b9bd5,color:#9cc4e8
    classDef yellow fill:#3a3000,stroke:#d4a017,color:#d4c87a
    classDef red fill:#3a0000,stroke:#cc3333,color:#ff8888
    classDef purple fill:#2a1a4a,stroke:#9966cc,color:#cc99ff
Loading

Stage 1-3 — Contribution pipeline (plans, implements, reviews, delivers):

flowchart TD
    subgraph L1["Layer 1 — Planning (8 plan-review rounds)"]
        PE[pattern-extractor] --> CE[coverage-engineer\nassessment mode]
        CE --> CP[contribution-planner]
        CP --> AR[architect-reviewer]
        AR -->|OBJECTIONS_RAISED| CP
        AR -->|NO_OBJECTIONS| impl
    end

    subgraph L2["Layer 2 — Implementation + Adversarial Loop"]
        impl([H.I.T.L.]) --> CA1[code-author\nimplementation pass]
        CA1 --> CA2[code-author\nregression tests]
        CA2 --> RE[readability-editor]
        RE --> VT[verbosity-trimmer]
        VT --> SL[scope-linter]
        SL --> SR[senior-reviewer\nATTACK]
        SL --> RT[red-team\nATTACK]
        SR --> SE[senior-engineer]
        RT --> SE
        SE --> TA[test-author]
        TA --> HM[cold-reviewer\nRE_REVIEW]
        TA --> RT2[red-team\nRE_REVIEW]
        HM --> SE2[senior-engineer\nrevision]
        RT2 --> SE2
        SE2 --> MG[merge-gate]
        MG -->|APPROVE| pr
        MG -->|REVISE| SE
    end

    subgraph L3["Layer 3 — PR Delivery"]
        pr([H.I.T.L]) --> PW[pr-writer]
        PW --> PRA[pr-adversary]
        PRA -->|REVISIONS_REQUESTED| PW
        PRA -->|PR_READY| LA[license-auditor]
        LA --> done([H.I.T.L])
    end
Loading

- Add tile_batch_size: int = 1 to tiled_encode_video, encode_video,
  compute_latents, and main() CLI (--tile-batch-size)
- Two-phase approach: Phase 1 collects tile metadata grouped by pixel
  shape (tile_h, tile_w); Phase 2 stacks same-shape tiles, calls vae
  once per mini-batch, splits with .chunk(len(sub_list), dim=0)
- Hoist four loop-invariant linspace feathering vectors above tile loop
- Add 12 tests in packages/ltx-trainer/tests/test_tiled_encode.py
- Fix ruff target-version from "1.1.3" to "py310" in pyproject.toml
@VittoriaLanzo Vittoria Lanzo (VittoriaLanzo) changed the title perf(trainer): batch tiled VAE encodes and hoist linspace allocations perf(trainer): add tile_batch_size, cut VAE calls up to 15x in tiled encoding May 14, 2026
@VittoriaLanzo Vittoria Lanzo (VittoriaLanzo) marked this pull request as ready for review May 14, 2026 11:52
@VittoriaLanzo
Copy link
Copy Markdown
Author

Vittoria Lanzo (VittoriaLanzo) commented May 14, 2026

Michael Kupchick (@michaellightricks) tiled encoding at 4K currently issues 60 sequential VAE calls per video. This PR adds --tile-batch-size which batches same-shape tiles into a single forward pass, bringing it down to 4 calls at 3840x2160 (one per shape group). Default is 1 so nothing changes unless you opt in. 12 tests included.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant