Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions packages/ltx-core/src/ltx_core/model/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
except ImportError:
flash_attn_interface = None

if flash_attn_interface is not None:
print("Using Flash Attention.")
elif memory_efficient_attention is not None:
print("Using xformers.")
else:
print("Neither Flash Attention nor xFormers xformers is available. Use PytorchAttention.")

class AttentionCallable(Protocol):
def __call__(
Expand Down Expand Up @@ -90,6 +96,7 @@ def __call__(
out = out.reshape(b, -1, heads * dim_head)
return out

_pytorch_attention = PytorchAttention()

class FlashAttention3(AttentionCallable):
def __call__(
Expand All @@ -103,18 +110,15 @@ def __call__(
if flash_attn_interface is None:
raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.")

if mask is not None:
# FA3 does not support arbitrary attention masks — fall back to PyTorch SDPA
return _pytorch_attention(q, k, v, heads, mask)

b, _, dim_head = q.shape
dim_head //= heads

q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v))

if mask is not None:
raise NotImplementedError("Mask is not supported for FlashAttention3")

out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v)
out = out.reshape(b, -1, heads * dim_head)
return out

return out.reshape(b, -1, heads * dim_head)

class AttentionFunction(Enum):
PYTORCH = "pytorch"
Expand All @@ -132,8 +136,13 @@ def to_callable(self) -> AttentionCallable:
elif self is AttentionFunction.FLASH_ATTENTION_3:
return FlashAttention3()
else:
# Default behavior: XFormers if installed else - PyTorch
return XFormersAttention() if memory_efficient_attention is not None else PytorchAttention()
# Default behavior: FA3 > XFormers > PyTorch
if flash_attn_interface is not None:
return FlashAttention3()
elif memory_efficient_attention is not None:
return XFormersAttention()
else:
return PytorchAttention()


class Attention(torch.nn.Module):
Expand Down