diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py
index 1369454e..8af13630 100644
--- a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py
+++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py
@@ -56,6 +56,7 @@ def __init__(
         loras: list[LoraPathStrengthAndSDOps],
         device: str = device,
         quantization: QuantizationPolicy | None = None,
+        cache_models: bool = False,
     ):
         self.device = device
         self.dtype = torch.bfloat16
@@ -67,6 +68,7 @@ def __init__(
             spatial_upsampler_path=spatial_upsampler_path,
             loras=loras,
             quantization=quantization,
+            cache_models=cache_models,
         )
 
         self.stage_2_model_ledger = self.stage_1_model_ledger.with_loras(
diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py
index d5e18c78..d39c0c01 100644
--- a/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py
+++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/model_ledger.py
@@ -51,10 +51,14 @@ class ModelLedger:
     :class:`~ltx_core.loader.registry.Registry` to load weights from the checkpoint, instantiates
     the model with the configured ``dtype``, and moves it to ``self.device``.
     .. note::
-        Models are **not cached**. Each call to a model method creates a new instance.
+        Models are **not cached** by default. Each call to a model method creates a new instance.
         Callers are responsible for storing references to models they wish to reuse and for
         freeing GPU memory (e.g. by deleting references and calling ``torch.cuda.empty_cache()``).
+
+        Set ``cache_models=True`` to enable model instance caching, which returns the same
+        instance on subsequent calls instead of rebuilding. Use :meth:`clear_model_cache`
+        to free cached models when done.
 
     ### Constructor parameters
     dtype:
         Torch dtype used when constructing all models (e.g. ``torch.bfloat16``).
@@ -81,6 +85,10 @@ class ModelLedger:
     quantization:
         Optional :class:`QuantizationPolicy` controlling how transformer weights are stored
         and how matmul is executed. Defaults to None, which means no quantization.
+    cache_models:
+        If ``True``, caches model instances so subsequent calls return the same instance
+        instead of rebuilding. Useful for repeated inference to avoid weight reloading.
+        Use :meth:`clear_model_cache` to free cached models when done.
 
     ### Creating Variants
     Use :meth:`with_loras` to create a new ``ModelLedger`` instance that includes additional
     LoRA configurations while sharing the same registry for weight caching.
@@ -96,6 +104,7 @@ def __init__(
         loras: LoraPathStrengthAndSDOps | None = None,
         registry: Registry | None = None,
         quantization: QuantizationPolicy | None = None,
+        cache_models: bool = False,
     ):
         self.dtype = dtype
         self.device = device
@@ -105,6 +114,8 @@ def __init__(
         self.loras = loras or ()
         self.registry = registry or DummyRegistry()
         self.quantization = quantization
+        self.cache_models = cache_models
+        self._model_cache: dict[str, torch.nn.Module] = {}
         self.build_model_builders()
 
     def build_model_builders(self) -> None:
@@ -181,6 +192,7 @@ def with_loras(self, loras: LoraPathStrengthAndSDOps) -> "ModelLedger":
             loras=(*self.loras, *loras),
             registry=self.registry,
             quantization=self.quantization,
+            cache_models=self.cache_models,
         )
 
     def transformer(self) -> X0Model:
@@ -188,10 +200,15 @@ def transformer(self) -> X0Model:
             raise ValueError(
                 "Transformer not initialized. Please provide a checkpoint path to the ModelLedger constructor."
             )
+        if self.cache_models and "transformer" in self._model_cache:
+            return self._model_cache["transformer"]
 
         if self.quantization is None:
-            return (
+            model = (
                 X0Model(self.transformer_builder.build(device=self._target_device(), dtype=self.dtype))
                 .to(self.device)
                 .eval()
             )
+            if self.cache_models:
+                self._model_cache["transformer"] = model
+            return model
@@ -207,23 +224,35 @@ def transformer(self) -> X0Model:
             module_ops=(*self.transformer_builder.module_ops, *self.quantization.module_ops),
             model_sd_ops=sd_ops,
         )
-        return X0Model(builder.build(device=self._target_device())).to(self.device).eval()
+        model = X0Model(builder.build(device=self._target_device())).to(self.device).eval()
+
+        if self.cache_models:
+            self._model_cache["transformer"] = model
+        return model
 
     def video_decoder(self) -> VideoDecoder:
         if not hasattr(self, "vae_decoder_builder"):
             raise ValueError(
                 "Video decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
             )
-
-        return self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models and "video_decoder" in self._model_cache:
+            return self._model_cache["video_decoder"]
+        model = self.vae_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models:
+            self._model_cache["video_decoder"] = model
+        return model
 
     def video_encoder(self) -> VideoEncoder:
         if not hasattr(self, "vae_encoder_builder"):
             raise ValueError(
                 "Video encoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
             )
-
-        return self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models and "video_encoder" in self._model_cache:
+            return self._model_cache["video_encoder"]
+        model = self.vae_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models:
+            self._model_cache["video_encoder"] = model
+        return model
 
     def text_encoder(self) -> AVGemmaTextEncoderModel:
         if not hasattr(self, "text_encoder_builder"):
@@ -231,27 +260,55 @@ def text_encoder(self) -> AVGemmaTextEncoderModel:
             "Text encoder not initialized. Please provide a checkpoint path and gemma root path to the "
             "ModelLedger constructor."
         )
-
-        return self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models and "text_encoder" in self._model_cache:
+            return self._model_cache["text_encoder"]
+        model = self.text_encoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models:
+            self._model_cache["text_encoder"] = model
+        return model
 
     def audio_decoder(self) -> AudioDecoder:
         if not hasattr(self, "audio_decoder_builder"):
             raise ValueError(
                 "Audio decoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
             )
-
-        return self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models and "audio_decoder" in self._model_cache:
+            return self._model_cache["audio_decoder"]
+        model = self.audio_decoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models:
+            self._model_cache["audio_decoder"] = model
+        return model
 
     def vocoder(self) -> Vocoder:
         if not hasattr(self, "vocoder_builder"):
             raise ValueError(
                 "Vocoder not initialized. Please provide a checkpoint path to the ModelLedger constructor."
             )
-
-        return self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models and "vocoder" in self._model_cache:
+            return self._model_cache["vocoder"]
+        model = self.vocoder_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models:
+            self._model_cache["vocoder"] = model
+        return model
 
     def spatial_upsampler(self) -> LatentUpsampler:
         if not hasattr(self, "upsampler_builder"):
             raise ValueError("Upsampler not initialized. Please provide upsampler path to the ModelLedger constructor.")
-
-        return self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models and "spatial_upsampler" in self._model_cache:
+            return self._model_cache["spatial_upsampler"]
+        model = self.upsampler_builder.build(device=self._target_device(), dtype=self.dtype).to(self.device).eval()
+        if self.cache_models:
+            self._model_cache["spatial_upsampler"] = model
+        return model
+
+    def clear_model_cache(self, model_name: str | None = None) -> None:
+        """Clear cached model instances to free GPU memory.
+
+        Args:
+            model_name: If provided, only clear this specific model (e.g. "transformer",
+                "video_encoder", "text_encoder"). Otherwise clear all cached models.
+        """
+        if model_name is not None:
+            self._model_cache.pop(model_name, None)
+        else:
+            self._model_cache.clear()