diff --git a/packages/ltx-core/src/ltx_core/model/transformer/attention.py b/packages/ltx-core/src/ltx_core/model/transformer/attention.py index 23e0447d..95fe5bda 100644 --- a/packages/ltx-core/src/ltx_core/model/transformer/attention.py +++ b/packages/ltx-core/src/ltx_core/model/transformer/attention.py @@ -18,6 +18,12 @@ except ImportError: flash_attn_interface = None +if flash_attn_interface is not None: + print("Using Flash Attention.") +elif memory_efficient_attention is not None: + print("Using xformers.") +else: + print("Neither Flash Attention nor xFormers xformers is available. Use PytorchAttention.") class AttentionCallable(Protocol): def __call__( @@ -90,6 +96,7 @@ def __call__( out = out.reshape(b, -1, heads * dim_head) return out +_pytorch_attention = PytorchAttention() class FlashAttention3(AttentionCallable): def __call__( @@ -103,18 +110,15 @@ def __call__( if flash_attn_interface is None: raise RuntimeError("FlashAttention3 was selected but `FlashAttention3` is not installed.") + if mask is not None: + # FA3 does not support arbitrary attention masks — fall back to PyTorch SDPA + return _pytorch_attention(q, k, v, heads, mask) + b, _, dim_head = q.shape dim_head //= heads - q, k, v = (t.view(b, -1, heads, dim_head) for t in (q, k, v)) - - if mask is not None: - raise NotImplementedError("Mask is not supported for FlashAttention3") - out = flash_attn_interface.flash_attn_func(q.to(v.dtype), k.to(v.dtype), v) - out = out.reshape(b, -1, heads * dim_head) - return out - + return out.reshape(b, -1, heads * dim_head) class AttentionFunction(Enum): PYTORCH = "pytorch" @@ -132,8 +136,13 @@ def to_callable(self) -> AttentionCallable: elif self is AttentionFunction.FLASH_ATTENTION_3: return FlashAttention3() else: - # Default behavior: XFormers if installed else - PyTorch - return XFormersAttention() if memory_efficient_attention is not None else PytorchAttention() + # Default behavior: FA3 > XFormers > PyTorch + if flash_attn_interface is not None: + return FlashAttention3() + elif memory_efficient_attention is not None: + return XFormersAttention() + else: + return PytorchAttention() class Attention(torch.nn.Module):