From e80dd2fcd2b34b91a22f63e3d1e4cb9f167e6b30 Mon Sep 17 00:00:00 2001 From: Mayank Gupta Date: Sun, 5 Apr 2026 12:17:44 +0530 Subject: [PATCH 1/2] Add Apple Silicon (MPS) device support across core and pipeline packages The codebase currently assumes CUDA for device synchronization, memory cleanup, RNG forking, and device detection. This causes crashes on Apple Silicon Macs (MPS backend) where torch.cuda.* APIs are unavailable. Changes: - layer_streaming.py: Guard torch.cuda.synchronize with availability check, add MPS synchronize fallback - fuse_loras.py: Add MPS device detection in _get_device() - base_encoder.py: Skip CUDA-only device list in torch.random.fork_rng on non-CUDA devices (MPS does not support fork_rng device pinning) - blocks.py: Guard CUDA synchronize and host cache cleanup with availability checks, add MPS synchronize fallback - gpu_model.py: Guard torch.cuda.synchronize in model cleanup context manager, add MPS fallback - helpers.py: Add MPS device detection in get_device() and MPS cache cleanup in cleanup_memory() Tested on M4 Max (128GB) with both two-stage and single-stage A2V pipelines generating 640x960 @ 25fps video from image+audio input. --- packages/ltx-core/src/ltx_core/layer_streaming.py | 5 ++++- .../ltx-core/src/ltx_core/loader/fuse_loras.py | 2 ++ .../text_encoders/gemma/encoders/base_encoder.py | 3 ++- .../src/ltx_pipelines/utils/blocks.py | 15 +++++++++------ .../src/ltx_pipelines/utils/gpu_model.py | 5 ++++- .../src/ltx_pipelines/utils/helpers.py | 10 ++++++++-- 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/packages/ltx-core/src/ltx_core/layer_streaming.py b/packages/ltx-core/src/ltx_core/layer_streaming.py index 05145afb..01107ade 100644 --- a/packages/ltx-core/src/ltx_core/layer_streaming.py +++ b/packages/ltx-core/src/ltx_core/layer_streaming.py @@ -284,7 +284,10 @@ def teardown(self) -> None: # Drain all in-flight async H2D copies, then release stream resources. # Without the synchronize, clearing the stream/events can trigger # use-after-free at the CUDA driver level. - torch.cuda.synchronize(device=self._target_device) + if torch.cuda.is_available(): + torch.cuda.synchronize(device=self._target_device) + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.synchronize() if self._prefetcher is not None: self._prefetcher.cleanup() self._prefetcher = None diff --git a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py index 00eecf64..ef78fac7 100644 --- a/packages/ltx-core/src/ltx_core/loader/fuse_loras.py +++ b/packages/ltx-core/src/ltx_core/loader/fuse_loras.py @@ -10,6 +10,8 @@ def _get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda", torch.cuda.current_device()) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") return torch.device("cpu") diff --git a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py index b0d9d241..25e38812 100644 --- a/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py +++ b/packages/ltx-core/src/ltx_core/text_encoders/gemma/encoders/base_encoder.py @@ -65,7 +65,8 @@ def _enhance( pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0 model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id) - with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]): + fork_devices = [self.model.device] if self.model.device.type == "cuda" else [] + with torch.inference_mode(), torch.random.fork_rng(devices=fork_devices): torch.manual_seed(seed) outputs = self.model.generate( **model_inputs, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py index 17cfb7b4..2ab8635c 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py @@ -107,12 +107,15 @@ def _streaming_model( # returned to the OS. Without this, sequential streaming models # (e.g. text encoder then transformer) exhaust host memory because the # CachingHostAllocator keeps freed blocks cached indefinitely. - torch.cuda.synchronize(device=target_device) - try: - if hasattr(torch._C, "_host_emptyCache"): - torch._C._host_emptyCache() - except Exception: - logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) + if torch.cuda.is_available(): + torch.cuda.synchronize(device=target_device) + try: + if hasattr(torch._C, "_host_emptyCache"): + torch._C._host_emptyCache() + except Exception: + logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.synchronize() def _build_state( diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py index da8c0bf0..f7af7776 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py @@ -23,7 +23,10 @@ def gpu_model(model: _M) -> Iterator[_M]: try: yield model finally: - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.synchronize() # .to("meta") releases storage for all parameters/buffers regardless # of their original device (CUDA or CPU). model.to("meta") diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py index 9f53421d..6e40b7f2 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py @@ -30,13 +30,19 @@ def get_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda", torch.cuda.current_device()) + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") return torch.device("cpu") def cleanup_memory() -> None: gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + torch.mps.empty_cache() + torch.mps.synchronize() def _conform_latent_length(latent: torch.Tensor, expected_frames_count: int) -> torch.Tensor: From 7044e31b202b594481ef470f3e1e95363fba2736 Mon Sep 17 00:00:00 2001 From: Mayank Gupta Date: Thu, 9 Apr 2026 00:29:17 +0530 Subject: [PATCH 2/2] Fix STG precision on MPS: compute guidance math in float32 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apple Silicon MPS compiles Metal kernels with fast math enabled (pytorch/pytorch#84936), causing ~1e-7 errors per operation that compound across 30 transformer blocks. The STG delta (cond - ptb) and CFG rescaling (std ratio) are particularly sensitive — the accumulated errors dominate the signal in bfloat16. Fix: upcast all guidance operands to float32 before computing the guidance formula (CFG + STG + modality deltas) and rescaling, then cast back to the original dtype. This eliminates catastrophic cancellation in the subtraction and precision loss in the std ratio. Zero performance impact — the guidance calculation operates on single tensors per step, not per block. Tested on M5 Max 128GB with LTX-2.3 A2V pipeline: - STG=1.0 stg_blocks=[28] now produces correct output on MPS - Lip sync quality improved vs STG=0 - No regression on output quality --- .../src/ltx_core/components/guiders.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/packages/ltx-core/src/ltx_core/components/guiders.py b/packages/ltx-core/src/ltx_core/components/guiders.py index ee68eab3..c90e201e 100644 --- a/packages/ltx-core/src/ltx_core/components/guiders.py +++ b/packages/ltx-core/src/ltx_core/components/guiders.py @@ -252,20 +252,36 @@ def calculate( The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for cfg and modality cfg, and as scale * (cond - uncond) for stg, steering the denoising process away from the unconditioned prediction. + + Guidance math is computed in float32 to avoid bfloat16 precision loss on MPS + (Apple Silicon). The STG delta (cond - perturbed) and rescaling (std ratio) + are particularly sensitive to catastrophic cancellation in low-precision dtypes. """ + dtype = cond.dtype + + # Upcast to float32 for precision — critical on MPS where fast-math and + # bfloat16 rounding cause ~1e-7 errors per op that compound across the + # guidance formula and rescaling (see pytorch/pytorch#84936). + def _f32(t: torch.Tensor | float) -> torch.Tensor | float: + return t.float() if isinstance(t, torch.Tensor) else t + + cond_f, uncond_text_f, uncond_perturbed_f, uncond_modality_f = ( + _f32(cond), _f32(uncond_text), _f32(uncond_perturbed), _f32(uncond_modality), + ) + pred = ( - cond - + (self.params.cfg_scale - 1) * (cond - uncond_text) - + self.params.stg_scale * (cond - uncond_perturbed) - + (self.params.modality_scale - 1) * (cond - uncond_modality) + cond_f + + (self.params.cfg_scale - 1) * (cond_f - uncond_text_f) + + self.params.stg_scale * (cond_f - uncond_perturbed_f) + + (self.params.modality_scale - 1) * (cond_f - uncond_modality_f) ) if self.params.rescale_scale != 0: - factor = cond.std() / pred.std() + factor = cond_f.std() / pred.std() factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale) pred = pred * factor - return pred + return pred.to(dtype) def do_unconditional_generation(self) -> bool: """Returns True if the guider is doing unconditional generation."""