Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions packages/ltx-core/src/ltx_core/components/guiders.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,36 @@ def calculate(
The guider calculates the guidance delta as (scale - 1) * (cond - uncond) for cfg and modality cfg,
and as scale * (cond - uncond) for stg, steering the denoising process away from the unconditioned
prediction.

Guidance math is computed in float32 to avoid bfloat16 precision loss on MPS
(Apple Silicon). The STG delta (cond - perturbed) and rescaling (std ratio)
are particularly sensitive to catastrophic cancellation in low-precision dtypes.
"""
dtype = cond.dtype

# Upcast to float32 for precision — critical on MPS where fast-math and
# bfloat16 rounding cause ~1e-7 errors per op that compound across the
# guidance formula and rescaling (see pytorch/pytorch#84936).
def _f32(t: torch.Tensor | float) -> torch.Tensor | float:
return t.float() if isinstance(t, torch.Tensor) else t

cond_f, uncond_text_f, uncond_perturbed_f, uncond_modality_f = (
_f32(cond), _f32(uncond_text), _f32(uncond_perturbed), _f32(uncond_modality),
)

pred = (
cond
+ (self.params.cfg_scale - 1) * (cond - uncond_text)
+ self.params.stg_scale * (cond - uncond_perturbed)
+ (self.params.modality_scale - 1) * (cond - uncond_modality)
cond_f
+ (self.params.cfg_scale - 1) * (cond_f - uncond_text_f)
+ self.params.stg_scale * (cond_f - uncond_perturbed_f)
+ (self.params.modality_scale - 1) * (cond_f - uncond_modality_f)
)

if self.params.rescale_scale != 0:
factor = cond.std() / pred.std()
factor = cond_f.std() / pred.std()
factor = self.params.rescale_scale * factor + (1 - self.params.rescale_scale)
pred = pred * factor

return pred
return pred.to(dtype)

def do_unconditional_generation(self) -> bool:
"""Returns True if the guider is doing unconditional generation."""
Expand Down
5 changes: 4 additions & 1 deletion packages/ltx-core/src/ltx_core/layer_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,10 @@ def teardown(self) -> None:
# Drain all in-flight async H2D copies, then release stream resources.
# Without the synchronize, clearing the stream/events can trigger
# use-after-free at the CUDA driver level.
torch.cuda.synchronize(device=self._target_device)
if torch.cuda.is_available():
torch.cuda.synchronize(device=self._target_device)
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.synchronize()
if self._prefetcher is not None:
self._prefetcher.cleanup()
self._prefetcher = None
Expand Down
2 changes: 2 additions & 0 deletions packages/ltx-core/src/ltx_core/loader/fuse_loras.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
def _get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda", torch.cuda.current_device())
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def _enhance(
pad_token_id = self.processor.tokenizer.pad_token_id if self.processor.tokenizer.pad_token_id is not None else 0
model_inputs = _pad_inputs_for_attention_alignment(model_inputs, pad_token_id=pad_token_id)

with torch.inference_mode(), torch.random.fork_rng(devices=[self.model.device]):
fork_devices = [self.model.device] if self.model.device.type == "cuda" else []
with torch.inference_mode(), torch.random.fork_rng(devices=fork_devices):
torch.manual_seed(seed)
outputs = self.model.generate(
**model_inputs,
Expand Down
15 changes: 9 additions & 6 deletions packages/ltx-pipelines/src/ltx_pipelines/utils/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,15 @@ def _streaming_model(
# returned to the OS. Without this, sequential streaming models
# (e.g. text encoder then transformer) exhaust host memory because the
# CachingHostAllocator keeps freed blocks cached indefinitely.
torch.cuda.synchronize(device=target_device)
try:
if hasattr(torch._C, "_host_emptyCache"):
torch._C._host_emptyCache()
except Exception:
logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True)
if torch.cuda.is_available():
torch.cuda.synchronize(device=target_device)
try:
if hasattr(torch._C, "_host_emptyCache"):
torch._C._host_emptyCache()
except Exception:
logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True)
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.synchronize()


def _build_state(
Expand Down
5 changes: 4 additions & 1 deletion packages/ltx-pipelines/src/ltx_pipelines/utils/gpu_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def gpu_model(model: _M) -> Iterator[_M]:
try:
yield model
finally:
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.synchronize()
# .to("meta") releases storage for all parameters/buffers regardless
# of their original device (CUDA or CPU).
model.to("meta")
Expand Down
10 changes: 8 additions & 2 deletions packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@
def get_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda", torch.cuda.current_device())
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")


def cleanup_memory() -> None:
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
torch.mps.empty_cache()
torch.mps.synchronize()


def _conform_latent_length(latent: torch.Tensor, expected_frames_count: int) -> torch.Tensor:
Expand Down