diff --git a/packages/ltx-core/src/ltx_core/components/guiders.py b/packages/ltx-core/src/ltx_core/components/guiders.py index ee68eab3..c90e201e 100644 --- a/packages/ltx-core/src/ltx_core/components/guiders.py +++ b/packages/ltx-core/src/ltx_core/components/guiders.py @@ -252,20 +252,36 @@ def calculate( The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for cfg and modality cfg, and as scale * (cond - uncond) for stg, steering the denoising process away from the unconditioned prediction. + + Guidance math is computed in float32 to avoid bfloat16 precision loss on MPS + (Apple Silicon). The STG delta (cond - perturbed) and rescaling (std ratio) + are particularly sensitive to catastrophic cancellation in low-precision dtypes. """ + dtype = cond.dtype + + # Upcast to float32 for precision — critical on MPS, where fast-math costs + # ~1e-7 per float32 op and bfloat16 rounding is far coarser (~4e-3 relative); + # both compound across the guidance formula and rescaling (see pytorch/pytorch#84936). + def _f32(t: torch.Tensor | float) -> torch.Tensor | float: + return t.float() if isinstance(t, torch.Tensor) else t + + cond_f, uncond_text_f, uncond_perturbed_f, uncond_modality_f = ( + _f32(cond), _f32(uncond_text), _f32(uncond_perturbed), _f32(uncond_modality), + ) + pred = ( - cond - + (self.params.cfg_scale - 1) * (cond - uncond_text) - + self.params.stg_scale * (cond - uncond_perturbed) - + (self.params.modality_scale - 1) * (cond - uncond_modality) + cond_f + + (self.params.cfg_scale - 1) * (cond_f - uncond_text_f) + + self.params.stg_scale * (cond_f - uncond_perturbed_f) + + (self.params.modality_scale - 1) * (cond_f - uncond_modality_f) ) if self.params.rescale_scale != 0: - factor = cond.std() / pred.std() + factor = cond_f.std() / pred.std() factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale) pred = pred * factor - return pred + return pred.to(dtype) def do_unconditional_generation(self) -> bool: """Returns True if the guider is doing 
unconditional generation.""" diff --git a/packages/ltx-core/src/ltx_core/layer_streaming.py b/packages/ltx-core/src/ltx_core/layer_streaming.py index 05145afb..01107ade 100644 --- a/packages/ltx-core/src/ltx_core/layer_streaming.py +++ b/packages/ltx-core/src/ltx_core/layer_streaming.py @@ -284,7 +284,10 @@ def teardown(self) -> None: # Drain all in-flight async H2D copies, then release stream resources. # Without the synchronize, clearing the stream/events can trigger # use-after-free at the CUDA driver level. - torch.cuda.synchronize(device=self._target_device) + if torch.cuda.is_available(): + torch.cuda.synchronize(device=self._target_device) + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.synchronize() if self._prefetcher is not None: self._prefetcher.cleanup() self._prefetcher = None diff --git a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py index 00eecf64..ef78fac7 100644 --- a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py +++ b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py @@ -10,6 +10,8 @@ def _get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda", torch.cuda.current_device()) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") return torch.device("cpu") diff --git a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py index b0d9d241..25e38812 100644 --- a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +++ b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py @@ -65,7 +65,8 @@ def _enhance( pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0 model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id) - with 
torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]): + fork_devices = [self.model.device] if self.model.device.type == "cuda" else [] + with torch.inference_mode(), torch.random.fork_rng(devices=fork_devices): torch.manual_seed(seed) outputs = self.model.generate( **model_inputs, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py index 17cfb7b4..2ab8635c 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py @@ -107,12 +107,15 @@ def _streaming_model( # returned to the OS. Without this, sequential streaming models # (e.g. text encoder then transformer) exhaust host memory because the # CachingHostAllocator keeps freed blocks cached indefinitely. - torch.cuda.synchronize(device=target_device) - try: - if hasattr(torch._C, "_host_emptyCache"): - torch._C._host_emptyCache() - except Exception: - logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) + if torch.cuda.is_available(): + torch.cuda.synchronize(device=target_device) + try: + if hasattr(torch._C, "_host_emptyCache"): + torch._C._host_emptyCache() + except Exception: + logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.synchronize() def _build_state( diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py index da8c0bf0..f7af7776 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py @@ -23,7 +23,10 @@ def gpu_model(model: _M) -> Iterator[_M]: try: yield model finally: - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + 
torch.mps.synchronize() # .to("meta") releases storage for all parameters/buffers regardless # of their original device (CUDA or CPU). model.to("meta") diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py index 9f53421d..6e40b7f2 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py @@ -30,13 +30,19 @@ def get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda", torch.cuda.current_device()) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") return torch.device("cpu") def cleanup_memory() -> None: gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.empty_cache() + torch.mps.synchronize() def _conform_latent_length(latent: torch.Tensor, expected_frames_count: int) -> torch.Tensor: