perf(trainer): add tile_batch_size, cut VAE calls up to 15x in tiled encoding #215
Open
Vittoria Lanzo (VittoriaLanzo) wants to merge 3 commits into
- Add tile_batch_size: int = 1 to tiled_encode_video, encode_video, compute_latents, and main() CLI (--tile-batch-size)
- Two-phase approach: Phase 1 collects tile metadata grouped by pixel shape (tile_h, tile_w); Phase 2 stacks same-shape tiles, calls vae once per mini-batch, splits with .chunk(len(sub_list), dim=0)
- Hoist four loop-invariant linspace feathering vectors above tile loop
- Add 12 tests in packages/ltx-trainer/tests/test_tiled_encode.py
- Fix ruff target-version from "1.1.3" to "py310" in pyproject.toml
Michael Kupchick (@michaellightricks) tiled encoding at 4K currently issues 60 sequential VAE calls per video. This PR adds |
The problem
When `--vae-tiling` is enabled, `tiled_encode_video` calls the VAE encoder once per spatial tile in a sequential loop: 15 calls at 1920×1080, 60 at 3840×2160. Each call is a full encoder round-trip, so the total cost grows linearly with tile count regardless of available GPU parallelism.

A secondary inefficiency: four feathering `torch.linspace` vectors (`fade_in_h`, `fade_out_h`, `fade_in_w`, `fade_out_w`) are reallocated inside the tile loop on every iteration despite depending only on `tile_overlap` and `VAE_SPATIAL_FACTOR`, which are constant for the duration of the call. At 3840×2160 this is up to 240 redundant small-tensor allocations per `tiled_encode_video` call.

Changes
Mini-batch VAE forward passes
A new parameter `tile_batch_size: int = 1` is added to `tiled_encode_video`, `encode_video`, `compute_latents`, and the CLI (`--tile-batch-size`).

Encoding uses a two-phase approach:
Phase 1 (collect): iterate tile positions in the original row-major order and record boundary metadata. Tiles are grouped by pixel shape `(tile_h, tile_w)`. Shape grouping is required because `torch.cat` requires identical non-batch dimensions: edge tiles are floor-aligned to `VAE_SPATIAL_FACTOR=32` and may have different pixel dimensions from interior tiles.

Phase 2 (encode + blend): for each shape group, split its tile list into sub-lists of at most `tile_batch_size`. Stack crops with `torch.cat(dim=0)`, call `vae` once, split with `.chunk(len(sub_list), dim=0)` (not `.chunk(tile_batch_size)`, so the final partial sub-list is never over-split), then run the existing blending accumulation unchanged per chunk.

`tile_batch_size=1` (default) is API-backward-compatible and numerically equivalent to the pre-patch code within float32 rounding tolerance.

Loop-invariant linspace hoisting
The four linspace allocations move to immediately after `overlap_out_h`/`overlap_out_w` are computed, before the tile loop. All guard conditions and `.view()` calls remain in-loop. This eliminates up to 4 × (tile count) redundant small-tensor allocations per call.

VAE call count reduction
Verified by running a call-counting wrapper on the patched code at default settings (`tile_size=512`, `tile_overlap=128`). "Minimum calls" requires `tile_batch_size >= largest group size` and equals the number of distinct shape groups.

[Table not recovered; one surviving column header reads "`tile_batch_size` for min".]

At every resolution there are at most four shape groups (at most 2 distinct heights × 2 distinct widths after edge rounding). No GPU benchmark was run; wall-time improvement depends on hardware and the per-call overhead of the real LTX-2 encoder.
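The link between `tile_batch_size` and the number of encoder calls follows from the sub-list split: a shape group of n tiles needs ceil(n / tile_batch_size) calls. The sketch below is a minimal illustration under stated assumptions, not the PR's code. `vae_call_count` is a hypothetical helper, and the group sizes are illustrative values chosen to match the 60-tile total and the 45-tile interior group reported for 3840×2160; the other three sizes are assumed.

```python
import math


def vae_call_count(group_sizes, tile_batch_size):
    """Encoder calls needed when each shape group is split into sub-lists.

    Hypothetical helper: group_sizes holds the tile count of each shape group.
    """
    if tile_batch_size < 1:
        raise ValueError("tile_batch_size must be >= 1")
    # Each group is encoded independently, so partial sub-lists round up.
    return sum(math.ceil(n / tile_batch_size) for n in group_sizes)


# Assumed group sizes for 3840x2160 (only the total of 60 and the
# 45-tile interior group are taken from the PR description).
group_sizes = [45, 9, 5, 1]

for tile_batch_size in (1, 8, 45):
    print(tile_batch_size, vae_call_count(group_sizes, tile_batch_size))
```

With these sizes the count falls from 60 calls at `tile_batch_size=1` to 4 once `tile_batch_size` reaches the largest group size, a 15x reduction that matches the figure in the PR title.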
Shape group breakdown
Dimensions are (tile_height x tile_width).
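A shape-group breakdown can be derived from the tile grid alone. The sketch below is an approximation, not the PR's implementation: `tile_extents` and `shape_groups` are hypothetical helpers, and the tiling rule (tiles start every `tile_size - tile_overlap` pixels, are clipped to the frame, then floor-aligned to 32) is an assumption. That assumption reproduces the 15-tile and 60-tile totals and the 45-tile interior group quoted in this description.

```python
from collections import Counter

VAE_SPATIAL_FACTOR = 32  # alignment value stated in the PR description


def tile_extents(size, tile_size, tile_overlap):
    """Pixel extent of each tile along one axis (assumed tiling rule)."""
    stride = tile_size - tile_overlap
    extents = []
    for start in range(0, size, stride):
        extent = min(start + tile_size, size) - start  # clip to the frame
        # Floor-align edge tiles to the VAE spatial factor.
        extents.append(extent // VAE_SPATIAL_FACTOR * VAE_SPATIAL_FACTOR)
    return extents


def shape_groups(height, width, tile_size=512, tile_overlap=128):
    """Count tiles per (tile_height, tile_width) shape."""
    heights = tile_extents(height, tile_size, tile_overlap)
    widths = tile_extents(width, tile_size, tile_overlap)
    return Counter((h, w) for h in heights for w in widths)


for height, width in [(1080, 1920), (2160, 3840)]:
    groups = shape_groups(height, width)
    print(f"{width}x{height}: {sum(groups.values())} tiles, "
          f"{len(groups)} shape groups")
    for (tile_h, tile_w), count in groups.items():
        print(f"  {tile_h} x {tile_w}: {count}")
```

Group sizes other than the 45-tile interior group are outputs of the assumed tiling rule and should be checked against the patched code.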
Benchmark script
Run this on your hardware to measure actual wall-time speedup with the real LTX-2 encoder:
Design choices
Two-phase design. Separating metadata collection (Phase 1, pure Python) from tensor operations (Phase 2) keeps each phase auditable in isolation and avoids deferring partial blending state mid-loop.
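The batching plan behind the two phases can be sketched in plain Python. This is a minimal illustration, not the PR's code: `plan_batches`, `encode_in_batches`, and `CountingEncoder` are hypothetical names, list operations stand in for `torch.cat` and `.chunk`, and each tile is represented only by its shape.

```python
from collections import defaultdict


def plan_batches(tile_shapes, tile_batch_size):
    """Phase 1: group tile indices by shape, then split into sub-lists.

    Returns one list of tile indices per encoder call.
    """
    if tile_batch_size < 1:
        raise ValueError("tile_batch_size must be >= 1")
    groups = defaultdict(list)  # (tile_h, tile_w) -> indices, row-major order
    for index, shape in enumerate(tile_shapes):
        groups[shape].append(index)
    batches = []
    for indices in groups.values():
        for start in range(0, len(indices), tile_batch_size):
            batches.append(indices[start:start + tile_batch_size])
    return batches


class CountingEncoder:
    """Stand-in for the VAE: counts calls, returns one output per input."""

    def __init__(self):
        self.calls = 0

    def __call__(self, stacked_tiles):
        self.calls += 1
        return [("latent", tile) for tile in stacked_tiles]


def encode_in_batches(tiles, batches, encoder):
    """Phase 2: one encoder call per sub-list, outputs mapped back by index."""
    outputs = [None] * len(tiles)
    for sub_list in batches:
        stacked = [tiles[i] for i in sub_list]  # stands in for torch.cat(dim=0)
        encoded = encoder(stacked)
        # One output per stacked tile, mirroring .chunk(len(sub_list), dim=0).
        assert len(encoded) == len(sub_list)
        for i, latent in zip(sub_list, encoded):
            outputs[i] = latent
    return outputs


# Two rows of three interior tiles plus one narrower edge tile, row-major.
interior, edge = (512, 512), (512, 384)
tile_shapes = [interior, interior, interior, edge] * 2

batches = plan_batches(tile_shapes, tile_batch_size=4)
encoder = CountingEncoder()
outputs = encode_in_batches(tile_shapes, batches, encoder)

print(batches)        # [[0, 1, 2, 4], [5, 6], [3, 7]]
print(encoder.calls)  # 3
```

The six interior tiles split into a full sub-list of four and a partial sub-list of two, and the two edge tiles form their own group, so eight tiles need three encoder calls instead of eight.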
`.chunk(len(sub_list))` not `.chunk(tile_batch_size)`. The final sub-list in a group may be smaller than `tile_batch_size`. Splitting by `len(sub_list)` ties the chunk count to the number of tiles actually stacked, so each chunk maps one-to-one onto a tile's metadata in the blending loop; splitting by `tile_batch_size` does not guarantee that for a partial sub-list.

Validation before fast-path. `tile_batch_size < 1` is rejected before the single-tile early return so the error fires regardless of video size, consistent with how `tile_size` and `tile_overlap` are validated unconditionally.

`pytest` in `dev` group. `uv run pytest` activates only the `dev` dependency group; a custom group would require `--group <name>` on every invocation.

Correctness
For `tile_batch_size > 1` to produce the same result as `tile_batch_size=1`, the VAE must produce the same output for a tile regardless of batch co-members. The parity tests confirm this for the test mock; verify on the real encoder before increasing `tile_batch_size`.

Usage
    python scripts/process_videos.py dataset.csv \
        --resolution-buckets 1920x1080x25 \
        --output-dir ./latents \
        --model-path /path/to/ltx2.safetensors \
        --vae-tiling \
        --tile-batch-size 8

`tile_batch_size=N` increases peak VRAM per forward pass by approximately N× the single-tile activation cost. `tile_batch_size=1` (default) is always safe. At 3840×2160 the interior shape group has 45 tiles, so `tile_batch_size=45` would batch all of them at once (peak activation ~45× per-tile cost); start at a low value and increase based on available VRAM.

Tests
New: `packages/ltx-trainer/tests/test_tiled_encode.py` - 12 tests, all fail on pre-patch code.

- `test_output_shape`
- `test_tile_batch_size_produces_identical_output`
- `test_tile_batch_size_identical_output_mixed_shapes`
- `test_tile_batch_size_larger_than_tile_count`: `tile_batch_size=100` on 4 tiles - partial last sub-list
- `test_tile_batch_size_zero_raises`: `0` -> `ValueError` before fast-path
- `test_tile_batch_size_negative_raises`: `-1` -> `ValueError` (guard is `< 1`, not `== 0`)
- `test_encode_video_threads_tile_batch_size`: `encode_video`
- `test_fast_path_single_tile`
- `test_mixed_shape_group_call_count`
- `test_tile_size_not_divisible_raises`
- `test_tile_overlap_not_divisible_raises`
- `test_tile_overlap_gte_tile_size_raises`

`test_tile_batch_size_zero_raises` and `test_tile_batch_size_negative_raises` are separate because changing the guard from `< 1` to `== 0` would pass the zero test but silently produce all-NaN output for negative values (empty `range` step).

`MockVAE` uses `Conv3d(stride=(8,32,32))` so its temporal output matches `1 + (F-1) // VAE_TEMPORAL_FACTOR`, preventing the shape mismatch a stride-1 mock would cause.

`pyproject.toml`: adds `pytest>=8.0` / `pytest-cov>=5.0` to `[dependency-groups] dev`; corrects `ruff` `target-version` from `"1.1.3"` (the project version string) to `"py310"`; adds `[tool.pytest.ini_options]`. These changes are bundled here because they are prerequisites for the test suite this PR adds: without the dev deps `uv run pytest` fails with "No module named pytest", and without the corrected `target-version` ruff cannot parse `pyproject.toml` and `ruff check` always errors. Happy to split into a separate preparatory PR if preferred.

Checklist
- `tile_batch_size=1` is backward-compatible
- `ruff check` passes
- `uv run pytest` passes

Agentic workflow
This contribution was produced with a multi-agent pipeline under my direction and supervision. I approved each stage at human checkpoints; the pipeline handled planning, implementation, testing, and adversarial review.
Stage 0: Performance assessment (identifies the tiled-encode hotspot):

    flowchart TD
        S([STARTUP\nroster check · pre-flight]):::green --> O[ORCHESTRATOR\nmain thread · routing only]:::blue
        O --> R[RECON AGENT\nstatic analysis · 5 passes · hotspot index]:::blue
        R -->|HOTSPOT_INDEX| D[DISPATCHER\nmechanical signal routing]:::yellow
        D -->|ROUTING_MANIFEST| C[COMPLEXITY AGENT\nnested loops · O n2]:::red
        D -->|ROUTING_MANIFEST| M[MEMORY AGENT\nalloc in loops · GC pressure]:::red
        D -->|ROUTING_MANIFEST| IO[IO AGENT\nN+1 · blocking async · locks]:::red
        D -->|ROUTING_MANIFEST| DS[DATASTRUCTURES AGENT\nwrong DS · linear search]:::red
        D -->|ROUTING_MANIFEST| CA[CACHE AGENT\nredundant compute · regex in loop]:::red
        C -->|FINDINGS| SC[SCORING COLLECTOR\nmerge · deduplicate · priority_score · top 15]:::yellow
        M -->|FINDINGS| SC
        IO -->|FINDINGS| SC
        DS -->|FINDINGS| SC
        CA -->|FINDINGS| SC
        SC -->|SCORED_FINDINGS| PQ[PRE-QUALIFICATION GATE\ngit log · gh pr · grep\neliminate owned or intentional findings]:::purple
        PQ -->|cleared_findings| TR[TOT ROOT\nBranch A: correctness bugs\nBranch B: high-value speed\nBranch C: risk / low-priority]:::purple
        TR -->|branches A · B · C| AS[ASSESSMENT SYNTHESIZER\nread code · confirm or discard each finding]:::purple
        AS -->|ASSESSMENT_REPORT| HR([HUMAN REVIEW GATE\npipeline halts · human decides]):::green
        classDef green fill:#1a4a1a,stroke:#4CAF50,color:#4CAF50
        classDef blue fill:#1a2a4a,stroke:#5b9bd5,color:#9cc4e8
        classDef yellow fill:#3a3000,stroke:#d4a017,color:#d4c87a
        classDef red fill:#3a0000,stroke:#cc3333,color:#ff8888
        classDef purple fill:#2a1a4a,stroke:#9966cc,color:#cc99ff

Stage 1-3: Contribution pipeline (plans, implements, reviews, delivers):
    flowchart TD
        subgraph L1["Layer 1 — Planning (8 plan-review rounds)"]
            PE[pattern-extractor] --> CE[coverage-engineer\nassessment mode]
            CE --> CP[contribution-planner]
            CP --> AR[architect-reviewer]
            AR -->|OBJECTIONS_RAISED| CP
            AR -->|NO_OBJECTIONS| impl
        end
        subgraph L2["Layer 2 — Implementation + Adversarial Loop"]
            impl([H.I.T.L.]) --> CA1[code-author\nimplementation pass]
            CA1 --> CA2[code-author\nregression tests]
            CA2 --> RE[readability-editor]
            RE --> VT[verbosity-trimmer]
            VT --> SL[scope-linter]
            SL --> SR[senior-reviewer\nATTACK]
            SL --> RT[red-team\nATTACK]
            SR --> SE[senior-engineer]
            RT --> SE
            SE --> TA[test-author]
            TA --> HM[cold-reviewer\nRE_REVIEW]
            TA --> RT2[red-team\nRE_REVIEW]
            HM --> SE2[senior-engineer\nrevision]
            RT2 --> SE2
            SE2 --> MG[merge-gate]
            MG -->|APPROVE| pr
            MG -->|REVISE| SE
        end
        subgraph L3["Layer 3 — PR Delivery"]
            pr([H.I.T.L]) --> PW[pr-writer]
            PW --> PRA[pr-adversary]
            PRA -->|REVISIONS_REQUESTED| PW
            PRA -->|PR_READY| LA[license-auditor]
            LA --> done([H.I.T.L])
        end