-
Notifications
You must be signed in to change notification settings - Fork 57
refactor: make flash-linear-attention optional via lazy-loading #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,4 +20,3 @@ | |
| "fused_sigmoid_gating_delta_rule_update", | ||
| "linear_attention_decode", | ||
| ] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,10 +32,8 @@ | |
| from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream | ||
| from cutlass.cute.typing import Float32, Int32, Int64 | ||
| from cutlass.cutlass_dsl import T as _T | ||
| from fla.ops.utils import prepare_chunk_indices, prepare_lens | ||
| from fla.utils import tensor_cache | ||
|
|
||
| from cula.utils import USE_FAST_MATH, assert_blackwell | ||
| from cula.utils import USE_FAST_MATH, assert_blackwell, prepare_chunk_indices, prepare_lens, tensor_cache | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wondering, are these types of changes auto-formatted, or did you do them manually? |
||
|
|
||
|
|
||
| # in FLA, cumsum returns int64 tensor by default | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -233,3 +233,62 @@ def _get_cache_buf(name: str, nbytes: int, device: torch.device) -> torch.Tensor | |
| buf = torch.empty(nbytes, dtype=torch.uint8, device=device) | ||
| _cache_buf[key] = buf | ||
| return buf | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Tensor cache | ||
| # Adapted from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py | ||
| # Original copyright: 2024 The FLA Authors (same Apache 2.0 license) | ||
| # --------------------------------------------------------------------------- | ||
| _CULA_DISABLE_TENSOR_CACHE: bool = os.getenv("CULA_DISABLE_TENSOR_CACHE", "0") == "1" | ||
|
|
||
|
|
||
| def tensor_cache(fn): | ||
| """Single-entry cache for functions with tensor inputs (identity-based).""" | ||
| last_args = None | ||
| last_kwargs = None | ||
| last_result = None | ||
|
|
||
| @functools.wraps(fn) | ||
| def wrapper(*args, **kwargs): | ||
| nonlocal last_args, last_kwargs, last_result | ||
| if _CULA_DISABLE_TENSOR_CACHE: | ||
| return fn(*args, **kwargs) | ||
| if last_args is not None and last_kwargs is not None: | ||
| if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): | ||
| if all(a is b for a, b in zip(args, last_args, strict=False)) and all( | ||
| k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items() | ||
| ): | ||
| return last_result | ||
| result = fn(*args, **kwargs) | ||
| last_args, last_kwargs, last_result = args, kwargs, result | ||
| return result | ||
|
|
||
| return wrapper | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Sequence-length helpers | ||
| # Adapted from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py | ||
| # Original copyright: 2024 The FLA Authors (same Apache 2.0 license) | ||
| # --------------------------------------------------------------------------- | ||
| @tensor_cache | ||
| def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: | ||
| return torch.diff(cu_seqlens) | ||
|
|
||
|
|
||
| @tensor_cache | ||
| def prepare_chunk_indices( | ||
| cu_seqlens: torch.LongTensor, | ||
| chunk_size: int, | ||
| cu_seqlens_cpu: torch.LongTensor | None = None, | ||
| ) -> torch.LongTensor: | ||
| import triton # already available as a transitive dep of cutlass-dsl | ||
|
|
||
| if cu_seqlens_cpu is not None: | ||
| indices = torch.cat( | ||
| [torch.arange(n, device=cu_seqlens.device) for n in triton.cdiv(prepare_lens(cu_seqlens_cpu), chunk_size).tolist()] | ||
| ) | ||
| return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) | ||
|
Comment on lines
+290
to
+292
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential bug and an efficiency issue in the
Removing the explicit indices = torch.cat(
[torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens_cpu), chunk_size).tolist()]
)
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) |
||
| indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) | ||
| return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment is in Chinese, which is inconsistent with the rest of the codebase. For better maintainability and accessibility for all contributors, it is recommended to use English for all comments.
# Map public interfaces to (module path, actual function name)