25 changes: 18 additions & 7 deletions fast_llm/data/document/block.py
@@ -24,8 +24,10 @@ class BlockModelInput(ModelInput):
     lengths: list[int] = None
     cumulative_lengths_q: torch.Tensor | None = None
     cumulative_lengths_k: torch.Tensor | None = None
-    max_length_q: torch.Tensor | None = None
-    max_length_k: torch.Tensor | None = None
+    max_length_q: int | None = None
+    max_length_k: int | None = None
+    min_length_q: int | None = None
+    min_length_k: int | None = None
     document_index_q: torch.Tensor | None = None
     document_index_k: torch.Tensor | None = None
     position_index: torch.Tensor | None = None
@@ -44,6 +46,8 @@ def to_kwargs(self) -> dict[str, typing.Any]:
             AttentionKwargs.cu_seqlens_k: self.cumulative_lengths_k,
             AttentionKwargs.max_seqlen_q: self.max_length_q,
             AttentionKwargs.max_seqlen_k: self.max_length_k,
+            AttentionKwargs.min_seqlen_q: self.min_length_q,
+            AttentionKwargs.min_seqlen_k: self.min_length_k,
             AttentionKwargs.document_index_q: self.document_index_q,
             AttentionKwargs.document_index_k: self.document_index_k,
             LanguageModelKwargs.position_ids: self.position_index,
@@ -101,6 +105,8 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo
         model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths
         if config.return_max_sequence_lengths or config.return_document_index:
             model_input.max_length_q, model_input.max_length_k = self.max_lengths
+        if config.return_min_sequence_lengths:
+            model_input.min_length_q, model_input.min_length_k = self.min_lengths
         if config.return_document_index:
             model_input.document_index_q, model_input.document_index_k = self.document_index
         if config.return_position_index:
@@ -118,13 +124,18 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
         return cumulative_lengths_q, cumulative_lengths_k

     @functools.cached_property
-    def max_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
+    def max_lengths(self) -> tuple[int, int]:
         max_length_q = max(self.lengths)
         max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin)
-        return (
-            torch.full((1,), max_length_q, dtype=torch.int32, device=self.device),
-            torch.full((1,), max_length_k, dtype=torch.int32, device=self.device),
-        )
+        return max_length_q, max_length_k
+
+    @functools.cached_property
+    def min_lengths(self) -> tuple[int, int]:
+        min_length_q = min(self.lengths)
+        # First doc's K-side length includes the past KV prefix; remaining docs match q-side.
+        first_length_k = self.sequence_k_past + self.lengths[0] - self.first_document_begin
+        min_length_k = min(first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else first_length_k
+        return min_length_q, min_length_k

     @functools.cached_property
     def document_index(self) -> tuple[torch.Tensor, torch.Tensor]:
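A quick worked example of the new `min_lengths`/`max_lengths` logic above (all numbers are invented; the variable names mirror the attributes the properties read):

```python
# Toy values; in the real class these come from the packed batch state.
lengths = [4, 6, 2]          # per-document q-side lengths in this window
sequence_k_past = 3          # cached K/V tokens preceding the window
first_document_begin = 1     # offset into the first (continued) document

max_length_q = max(lengths)  # 6
# Only the first document's K side is extended by the cached prefix.
first_length_k = sequence_k_past + lengths[0] - first_document_begin  # 3 + 4 - 1 = 6
max_length_k = max(max_length_q, first_length_k)  # 6

min_length_q = min(lengths)  # 2
min_length_k = min(first_length_k, *lengths[1:]) if len(lengths) > 1 else first_length_k  # 2
```

Note that returning plain ints (instead of the old one-element `torch.full` tensors) is what lets these values be handed to SDPA without a device sync.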
1 change: 1 addition & 0 deletions fast_llm/data/document/config.py

@@ -25,6 +25,7 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
     distributed: DistributedConfig = Field()
     return_cumulative_sequence_lengths: bool = Field(default=False)
     return_max_sequence_lengths: bool = Field(default=False)
+    return_min_sequence_lengths: bool = Field(default=False)
     return_document_index: bool = Field(default=False)
     return_position_index: bool = Field(default=False)
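A hedged sketch of how the new flag is meant to flow; only the class and field names come from this diff, and the construction details are assumed:

```python
# Hypothetical wiring; the real call sites are elsewhere in the repo.
from fast_llm.data.document.config import LengthPreprocessingConfig

config = LengthPreprocessingConfig(
    distributed=distributed_config,  # assumed to already exist in scope
    return_cumulative_sequence_lengths=True,
    return_max_sequence_lengths=True,
    return_min_sequence_lengths=True,  # the new flag
)
# With this config, the length preprocessor (block.py above) fills
# model_input.min_length_q / min_length_k as plain ints, and to_kwargs()
# exposes them under AttentionKwargs.min_seqlen_q / min_seqlen_k.
```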
104 changes: 95 additions & 9 deletions fast_llm/layers/attention/attention.py

@@ -81,10 +81,14 @@ def __init__(
         )
         self._implementation = self._config.implementation
         if self._implementation == AttentionImplementation.auto:
-            if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16):
+            if (
+                _flash_available
+                and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16)
+                and self._config.head_size <= 256
+            ):
                 self._implementation = AttentionImplementation.flash
             else:
-                self._implementation = AttentionImplementation.backup
+                self._implementation = AttentionImplementation.sdpa

         self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
         self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim(
@@ -258,6 +262,68 @@ def _attn_flash(
             softmax_scale=self._softmax_scale,
         )

+    def _attn_sdpa(
+        self,
+        query: torch.Tensor,  # total_q, heads, head_size
+        key: torch.Tensor,  # total_k, head_groups, head_size
+        value: torch.Tensor,  # total_k, head_groups, head_size
+        kwargs: dict[str, typing.Any],
+    ) -> torch.Tensor:  # total_q, heads, head_size
+        # SDPA's fused kernels require Q/K/V to share a head count, so we expand K/V across query heads.
+        if self._local_heads_per_group > 1:
+            key = key.repeat_interleave(self._local_heads_per_group, dim=1)
+            value = value.repeat_interleave(self._local_heads_per_group, dim=1)
+
+        sdpa_args: dict[str, typing.Any] = {
+            "dropout_p": self._config.dropout if self.training else 0.0,
+            "scale": self._softmax_scale,
+        }
+        if query.is_cuda and self._config.window_size is None:
+            # Wrap each document as its own batch element via a nested-jagged tensor, so
+            # cross-document masking is structural and the memory-efficient backend skips
+            # materializing the attention mask. SDPA's dispatch otherwise copies `max_seqlen`/
+            # `min_seqlen` to the host on every call; passing them explicitly keeps it sync-free.
+            cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
+            cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
+            query = torch.nested.nested_tensor_from_jagged(
+                query,
+                cu_seqlens_q,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_q],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_q],
+            )
+            key = torch.nested.nested_tensor_from_jagged(
+                key,
+                cu_seqlens_k,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
+            )
+            value = torch.nested.nested_tensor_from_jagged(
+                value,
+                cu_seqlens_k,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
+            )
+            sdpa_args["is_causal"] = self._config.causal
+        else:
+            # Dense fallback through the backup path's preprocessed causal+document mask: required
+            # on CPU (the MATH backend rejects nested + is_causal) and on CUDA with a sliding window
+            # (which the nested path cannot express). Backup builds (1, sq, 1, sk); SDPA wants (B, H, sq, sk).
+            attention_mask = kwargs[AttentionKwargs.attention_mask]
+            if attention_mask is not None:
+                attention_mask = attention_mask.transpose(1, 2)
+            query = query.unsqueeze(0)
+            key = key.unsqueeze(0)
+            value = value.unsqueeze(0)
+            sdpa_args["attn_mask"] = attention_mask
+
+        output = torch.nn.functional.scaled_dot_product_attention(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2),
+            **sdpa_args,
+        ).transpose(1, 2)
+        return output.values() if output.is_nested else output.squeeze(0)

     def _apply_norm_with_grad_capture(
         self, norm: torch.nn.Module, x: torch.Tensor
     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
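For reference, the nested-jagged branch in isolation: a minimal runnable sketch (assumes PyTorch >= 2.4 with CUDA; `torch.nested.nested_tensor_from_jagged` and its `min_seqlen`/`max_seqlen` arguments are the same APIs the new method calls, but the sizes, dtype, and data here are made up):

```python
import torch
import torch.nn.functional as F

heads, head_size = 8, 64
lengths = [5, 3, 7]  # three packed documents
offsets = torch.tensor([0, 5, 8, 15], dtype=torch.int64, device="cuda")

total = sum(lengths)
q = torch.randn(total, heads, head_size, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

def as_nested(x: torch.Tensor) -> torch.Tensor:
    # Passing min/max seqlen explicitly avoids the device-to-host sync the
    # PR's comment mentions: SDPA's dispatch would otherwise compute them.
    return torch.nested.nested_tensor_from_jagged(
        x, offsets, min_seqlen=min(lengths), max_seqlen=max(lengths)
    )

# Each document is its own batch element, so `is_causal=True` masks within
# documents and attention never crosses document boundaries.
out = F.scaled_dot_product_attention(
    as_nested(q).transpose(1, 2),  # (docs, heads, jagged, head_size)
    as_nested(k).transpose(1, 2),
    as_nested(v).transpose(1, 2),
    is_causal=True,
).transpose(1, 2)

packed = out.values()  # back to (total, heads, head_size)
assert packed.shape == (total, heads, head_size)
```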
@@ -420,6 +486,8 @@ def _forward(
         with set_generator(self._distributed.tp_generator):
             if self._implementation == AttentionImplementation.flash:
                 input_ = self._attn_flash(query, key, value, kwargs)
+            elif self._implementation == AttentionImplementation.sdpa:
+                input_ = self._attn_sdpa(query, key, value, kwargs)
             elif self._implementation == AttentionImplementation.backup:
                 # TODO: Avoid the flattens.
                 input_ = self._attn_backup(query, key, value, kwargs)
@@ -472,7 +540,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c

         attention_compute = sequence_q * sequence_k * attn_compute_base

-        if (not config.hardware) or self._implementation in AttentionImplementation.flash:
+        if (not config.hardware) or self._implementation in (
+            AttentionImplementation.flash,
+            AttentionImplementation.sdpa,
+        ):
             # Remove non-causal part. (TODO: Support non-causal)
             # TODO: Compute is overestimated without cross-document attention.
             attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2
@@ -498,19 +569,34 @@
         )

     def get_preprocessing_config(self) -> dict[str, typing.Any]:
-        return (
-            {
+        if self._implementation == AttentionImplementation.flash:
+            return {
                 "return_cumulative_sequence_lengths": True,
                 "return_max_sequence_lengths": True,
                 "causal": self._config.causal,
             }
-            if self._implementation == AttentionImplementation.flash
-            else {"return_document_index": True, "causal": self._config.causal}
-        )
+        elif (
+            self._implementation == AttentionImplementation.sdpa
+            and self._distributed_config.use_cuda
+            and self._config.window_size is None
+        ):
+            return {
+                "return_cumulative_sequence_lengths": True,
+                "return_max_sequence_lengths": True,
+                "return_min_sequence_lengths": True,
+                "causal": self._config.causal,
+            }
+        elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup):
+            return {"return_document_index": True, "causal": self._config.causal}
+        else:
+            raise NotImplementedError(self._implementation)

     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self._rotary.preprocess(kwargs)
-        if self._implementation == AttentionImplementation.backup:
+        if self._implementation == AttentionImplementation.backup or (
+            self._implementation == AttentionImplementation.sdpa
+            and (not self._distributed_config.use_cuda or self._config.window_size is not None)
+        ):
            self._preprocess_for_backup_attention(kwargs)

     def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
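One note on the `repeat_interleave` expansion at the top of `_attn_sdpa`: it materializes K/V once per query head, a memory trade-off relative to flash-attn's native GQA support. A tiny CPU check (toy sizes, not from the PR) of the head-to-group mapping it produces:

```python
import torch

total, head_groups, heads_per_group, head_size = 5, 2, 3, 4
k = torch.randn(total, head_groups, head_size)
k_expanded = k.repeat_interleave(heads_per_group, dim=1)  # (5, 6, 4)

# Query heads 0-2 share group 0; heads 3-5 share group 1.
assert torch.equal(k_expanded[:, 2], k[:, 0])
assert torch.equal(k_expanded[:, 3], k[:, 1])
```

Newer PyTorch releases also expose an `enable_gqa` argument on `scaled_dot_product_attention` that may avoid this copy; whether it composes with the nested-jagged path used here is untested.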
3 changes: 3 additions & 0 deletions fast_llm/layers/attention/config.py

@@ -21,6 +21,8 @@ class MixerKwargs(BlockKwargs):
     cu_seqlens_k = "cu_seqlens_k"
     max_seqlen_q = "max_seqlen_q"
     max_seqlen_k = "max_seqlen_k"
+    min_seqlen_q = "min_seqlen_q"
+    min_seqlen_k = "min_seqlen_k"
     document_index_q = "document_index_q"
     document_index_k = "document_index_k"
     position_ids = "position_ids"

@@ -38,6 +40,7 @@
 class AttentionImplementation(enum.StrEnum):
     auto = "auto"
     flash = "flash"
+    sdpa = "sdpa"
     backup = "backup"