2 changes: 2 additions & 0 deletions packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py
@@ -56,6 +56,7 @@ def __init__(
loras: list[LoraPathStrengthAndSDOps],
device: str = device,
quantization: QuantizationPolicy | None = None,
cache_models: bool = False,
):
self.device = device
self.dtype = torch.bfloat16
@@ -67,6 +68,7 @@ def __init__(
spatial_upsampler_path=spatial_upsampler_path,
loras=loras,
quantization=quantization,
cache_models=cache_models,
)

self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
82 changes: 68 additions & 14 deletions packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py
@@ -51,10 +51,14 @@ class ModelLedger:
:class:`~ltx_core.loader.registry.Registry` to load weights from the checkpoint,
instantiates the model with the configured ``dtype``, and moves it to ``self.device``.
.. note::
Models are **not cached**. Each call to a model method creates a new instance.
Models are **not cached** by default. Each call to a model method creates a new instance.
Callers are responsible for storing references to models they wish to reuse
and for freeing GPU memory (e.g. by deleting references and calling
``torch.cuda.empty_cache()``).

Set ``cache_models=True`` to enable model instance caching, which returns the same
instance on subsequent calls instead of rebuilding. Use :meth:`clear_model_cache`
to free cached models when done.
### Constructor parameters
dtype:
Torch dtype used when constructing all models (e.g. ``torch.bfloat16``).
@@ -81,6 +85,10 @@ class ModelLedger:
quantization:
Optional :class:`QuantizationPolicy` controlling how transformer weights
are stored and how matmul is executed. Defaults to None, which means no quantization.
cache_models:
If ``True``, caches model instances so subsequent calls return the same instance
instead of rebuilding. Useful for repeated inference to avoid weight reloading.
Use :meth:`clear_model_cache` to free cached models when done.
### Creating Variants
Use :meth:`with_loras` to create a new ``ModelLedger`` instance that includes
additional LoRA configurations while sharing the same registry for weight caching.
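The caching semantics described in the docstring can be sketched as follows. This is a minimal, illustrative snippet, not part of the diff: the checkpoint, LoRA, registry, and quantization constructor arguments are elided, and the device value is an assumption.

```python
import torch

# Minimal sketch of the cache_models behaviour described above.
# Checkpoint/LoRA/registry/quantization arguments are elided here.
ledger = ModelLedger(
    dtype=torch.bfloat16,
    device="cuda",
    cache_models=True,
    # ... checkpoint paths, loras, registry, quantization omitted ...
)

t1 = ledger.transformer()  # first call builds the model and stores it in the cache
t2 = ledger.transformer()  # second call returns the cached instance
assert t1 is t2            # same object, no weight reloading

# With the default cache_models=False, each call would build a fresh instance
# and the caller would be responsible for keeping or freeing references.
```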
@@ -96,6 +104,7 @@ def __init__(
loras: LoraPathStrengthAndSDOps | None = None,
registry: Registry | None = None,
quantization: QuantizationPolicy | None = None,
cache_models: bool = False,
):
self.dtype = dtype
self.device = device
@@ -105,6 +114,8 @@ def __init__(
self.loras = loras or ()
self.registry = registry or DummyRegistry()
self.quantization = quantization
self.cache_models = cache_models
self._model_cache: dict[str, torch.nn.Module] = {}
self.build_model_builders()

def build_model_builders(self) -> None:
@@ -181,16 +192,19 @@ def with_loras(self, loras: LoraPathStrengthAndSDOps) -> "ModelLedger":
loras=(*self.loras, *loras),
registry=self.registry,
quantization=self.quantization,
cache_models=self.cache_models,
)

def transformer(self) -> X0Model:
if not hasattr(self, "transformer_builder"):
raise ValueError(
"Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor."
)
if self.cache_models and "transformer" in self._model_cache:
return self._model_cache["transformer"]

if self.quantization is None:
return (
model = (
X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype))
.to(self.device)
.eval()
@@ -207,51 +221,91 @@ def transformer(self) -> X0Model:
module_ops=(*self.transformer_builder.module_ops, *self.quantization.module_ops),
model_sd_ops=sd_ops,
)
return X0Model(builder.build(device=self._target_device())).to(self.device).eval()
model = X0Model(builder.build(device=self._target_device())).to(self.device).eval()

if self.cache_models:
self._model_cache["transformer"] = model
return model

def video_decoder(self) -> VideoDecoder:
if not hasattr(self, "vae_decoder_builder"):
raise ValueError(
"Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
)

return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models and "video_decoder" in self._model_cache:
return self._model_cache["video_decoder"]
model = self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models:
self._model_cache["video_decoder"] = model
return model

def video_encoder(self) -> VideoEncoder:
if not hasattr(self, "vae_encoder_builder"):
raise ValueError(
"Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
)

return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models and "video_encoder" in self._model_cache:
return self._model_cache["video_encoder"]
model = self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models:
self._model_cache["video_encoder"] = model
return model

def text_encoder(self) -> AVGemmaTextEncoderModel:
if not hasattr(self, "text_encoder_builder"):
raise ValueError(
"Text encoder not initialized. Please provide a checkpoint path and gemma root path to the "
"ModelLedger constructor."
)

return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models and "text_encoder" in self._model_cache:
return self._model_cache["text_encoder"]
model = self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models:
self._model_cache["text_encoder"] = model
return model

def audio_decoder(self) -> AudioDecoder:
if not hasattr(self, "audio_decoder_builder"):
raise ValueError(
"Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
)

return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models and "audio_decoder" in self._model_cache:
return self._model_cache["audio_decoder"]
model = self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models:
self._model_cache["audio_decoder"] = model
return model

def vocoder(self) -> Vocoder:
if not hasattr(self, "vocoder_builder"):
raise ValueError(
"Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
)

return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models and "vocoder" in self._model_cache:
return self._model_cache["vocoder"]
model = self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models:
self._model_cache["vocoder"] = model
return model

def spatial_upsampler(self) -> LatentUpsampler:
if not hasattr(self, "upsampler_builder"):
raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.")
if self.cache_models and "spatial_upsampler" in self._model_cache:
return self._model_cache["spatial_upsampler"]
model = self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
if self.cache_models:
self._model_cache["spatial_upsampler"] = model
return model

return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
def clear_model_cache(self, model_name: str | None = None) -> None:
"""Clear cached model instances to free GPU memory.

Args:
model_name: If provided, clear only that specific model (e.g. "transformer",
"video_encoder", "text_encoder"). If None, clear all cached models.
"""
if model_name is not None:
self._model_cache.pop(model_name, None)
else:
self._model_cache.clear()
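A hedged sketch of releasing memory once the cached models are no longer needed, continuing the snippet above; the final `torch.cuda.empty_cache()` call follows the advice already given in the class docstring.

```python
# Continuing the earlier sketch: evict cached models when inference is done.
ledger.clear_model_cache("transformer")  # evict a single cached model
ledger.clear_model_cache()               # or evict everything in the cache

# As the class docstring notes, GPU memory is only reclaimed once all
# remaining references are dropped and PyTorch releases its cached blocks.
del t1, t2
torch.cuda.empty_cache()
```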