diff --git a/fast_llm/data/document/block.py b/fast_llm/data/document/block.py index 530be42ea..b9f2e30f2 100644 --- a/fast_llm/data/document/block.py +++ b/fast_llm/data/document/block.py @@ -24,8 +24,10 @@ class BlockModelInput(ModelInput): lengths: list[int] = None cumulative_lengths_q: torch.Tensor | None = None cumulative_lengths_k: torch.Tensor | None = None - max_length_q: torch.Tensor | None = None - max_length_k: torch.Tensor | None = None + max_length_q: int | None = None + max_length_k: int | None = None + min_length_q: int | None = None + min_length_k: int | None = None document_index_q: torch.Tensor | None = None document_index_k: torch.Tensor | None = None position_index: torch.Tensor | None = None @@ -44,6 +46,8 @@ def to_kwargs(self) -> dict[str, typing.Any]: AttentionKwargs.cu_seqlens_k: self.cumulative_lengths_k, AttentionKwargs.max_seqlen_q: self.max_length_q, AttentionKwargs.max_seqlen_k: self.max_length_k, + AttentionKwargs.min_seqlen_q: self.min_length_q, + AttentionKwargs.min_seqlen_k: self.min_length_k, AttentionKwargs.document_index_q: self.document_index_q, AttentionKwargs.document_index_k: self.document_index_k, LanguageModelKwargs.position_ids: self.position_index, @@ -101,6 +105,8 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths if config.return_max_sequence_lengths or config.return_document_index: model_input.max_length_q, model_input.max_length_k = self.max_lengths + if config.return_min_sequence_lengths: + model_input.min_length_q, model_input.min_length_k = self.min_lengths if config.return_document_index: model_input.document_index_q, model_input.document_index_k = self.document_index if config.return_position_index: @@ -118,13 +124,18 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]: return cumulative_lengths_q, cumulative_lengths_k @functools.cached_property - def max_lengths(self) -> tuple[torch.Tensor, torch.Tensor]: + def max_lengths(self) -> tuple[int, int]: max_length_q = max(self.lengths) max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin) - return ( - torch.full((1,), max_length_q, dtype=torch.int32, device=self.device), - torch.full((1,), max_length_k, dtype=torch.int32, device=self.device), - ) + return max_length_q, max_length_k + + @functools.cached_property + def min_lengths(self) -> tuple[int, int]: + min_length_q = min(self.lengths) + # First doc's K-side length includes the past KV prefix; remaining docs match q-side. + first_length_k = self.sequence_k_past + self.lengths[0] - self.first_document_begin + min_length_k = min(first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else first_length_k + return min_length_q, min_length_k @functools.cached_property def document_index(self) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 352311b51..a90bcdebc 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -25,6 +25,7 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig): distributed: DistributedConfig = Field() return_cumulative_sequence_lengths: bool = Field(default=False) return_max_sequence_lengths: bool = Field(default=False) + return_min_sequence_lengths: bool = Field(default=False) return_document_index: bool = Field(default=False) return_position_index: bool = Field(default=False) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 12f85bf28..9ff1f7846 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -81,10 +81,14 @@ def __init__( ) self._implementation = self._config.implementation if self._implementation == AttentionImplementation.auto: - if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16): + if ( + _flash_available + and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) + and self._config.head_size <= 256 + ): self._implementation = AttentionImplementation.flash else: - self._implementation = AttentionImplementation.backup + self._implementation = AttentionImplementation.sdpa self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -258,6 +262,68 @@ def _attn_flash( softmax_scale=self._softmax_scale, ) + def _attn_sdpa( + self, + query: torch.Tensor, # total_q, heads, head_size + key: torch.Tensor, # total_k, head_groups, head_size + value: torch.Tensor, # total_k, head_groups, head_size + kwargs: dict[str, typing.Any], + ) -> torch.Tensor: # total_q, heads, head_size + # SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads. + if self._local_heads_per_group > 1: + key = key.repeat_interleave(self._local_heads_per_group, dim=1) + value = value.repeat_interleave(self._local_heads_per_group, dim=1) + + sdpa_args: dict[str, typing.Any] = { + "dropout_p": self._config.dropout if self.training else 0.0, + "scale": self._softmax_scale, + } + if query.is_cuda and self._config.window_size is None: + # Wrap each document as its own batch element via nested-jagged so cross-doc masking + # is structural and EFFICIENT skips materializing the attention mask. The dispatch + # otherwise reads `max_seqlen`/`min_seqlen` to host on every call; passing them in + # explicitly keeps the path sync-free. + cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64) + cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64) + query = torch.nested.nested_tensor_from_jagged( + query, + cu_seqlens_q, + min_seqlen=kwargs[AttentionKwargs.min_seqlen_q], + max_seqlen=kwargs[AttentionKwargs.max_seqlen_q], + ) + key = torch.nested.nested_tensor_from_jagged( + key, + cu_seqlens_k, + min_seqlen=kwargs[AttentionKwargs.min_seqlen_k], + max_seqlen=kwargs[AttentionKwargs.max_seqlen_k], + ) + value = torch.nested.nested_tensor_from_jagged( + value, + cu_seqlens_k, + min_seqlen=kwargs[AttentionKwargs.min_seqlen_k], + max_seqlen=kwargs[AttentionKwargs.max_seqlen_k], + ) + sdpa_args["is_causal"] = self._config.causal + else: + # Dense + backup's preprocessed causal+document mask. Required on CPU (MATH rejects + # nested + is_causal) and on CUDA with sliding window (the nested path can't express + # it). Backup builds the mask as (1, sq, 1, sk); SDPA wants (B, H, sq, sk). + attention_mask = kwargs[AttentionKwargs.attention_mask] + if attention_mask is not None: + attention_mask = attention_mask.transpose(1, 2) + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + sdpa_args["attn_mask"] = attention_mask + + output = torch.nn.functional.scaled_dot_product_attention( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + **sdpa_args, + ).transpose(1, 2) + return output.values() if output.is_nested else output.squeeze(0) + def _apply_norm_with_grad_capture( self, norm: torch.nn.Module, x: torch.Tensor ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]: @@ -420,6 +486,8 @@ def _forward( with set_generator(self._distributed.tp_generator): if self._implementation == AttentionImplementation.flash: input_ = self._attn_flash(query, key, value, kwargs) + elif self._implementation == AttentionImplementation.sdpa: + input_ = self._attn_sdpa(query, key, value, kwargs) elif self._implementation == AttentionImplementation.backup: # TODO: Avoid the flattens. input_ = self._attn_backup(query, key, value, kwargs) @@ -472,7 +540,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._implementation in AttentionImplementation.flash: + if (not config.hardware) or self._implementation in ( + AttentionImplementation.flash, + AttentionImplementation.sdpa, + ): # Remove non-causal part. (TODO: Support non-causal) # TODO: Compute is overestimated without cross-document attention. attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 @@ -498,19 +569,34 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c ) def get_preprocessing_config(self) -> dict[str, typing.Any]: - return ( - { + if self._implementation == AttentionImplementation.flash: + return { "return_cumulative_sequence_lengths": True, "return_max_sequence_lengths": True, "causal": self._config.causal, } - if self._implementation == AttentionImplementation.flash - else {"return_document_index": True, "causal": self._config.causal} - ) + elif ( + self._implementation == AttentionImplementation.sdpa + and self._distributed_config.use_cuda + and self._config.window_size is None + ): + return { + "return_cumulative_sequence_lengths": True, + "return_max_sequence_lengths": True, + "return_min_sequence_lengths": True, + "causal": self._config.causal, + } + elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup): + return {"return_document_index": True, "causal": self._config.causal} + else: + raise NotImplementedError(self._implementation) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(kwargs) - if self._implementation == AttentionImplementation.backup: + if self._implementation == AttentionImplementation.backup or ( + self._implementation == AttentionImplementation.sdpa + and (not self._distributed_config.use_cuda or self._config.window_size is not None) + ): self._preprocess_for_backup_attention(kwargs) def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None: diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index cc5d80e88..f69e2129d 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -21,6 +21,8 @@ class MixerKwargs(BlockKwargs): cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" max_seqlen_k = "max_seqlen_k" + min_seqlen_q = "min_seqlen_q" + min_seqlen_k = "min_seqlen_k" document_index_q = "document_index_q" document_index_k = "document_index_k" position_ids = "position_ids" @@ -38,6 +40,7 @@ class AttentionKwargs(MixerKwargs): class AttentionImplementation(enum.StrEnum): auto = "auto" flash = "flash" + sdpa = "sdpa" backup = "backup" diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index d572816b2..ef96dbf96 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -174,78 +174,101 @@ def expected_output( return torch.nn.functional.linear(attn_out.flatten(1), attention.dense.weight.detach()) -_base_attention_cases = [ +_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]] +_LENGTHS_SHORT = [[15], [4, 1, 10]] +_LENGTHS_SINGLE = [[15]] + +_attention_test_cases: list[tuple[AttentionTestConfig, list[int]]] = [] + +# Mask, group, and window base cases — no norms, swept over all length sets. +for name, kwargs in ( ("causal", {"causal": True}), ("noncausal", {"causal": False}), ("window", {"causal": True, "window_size": 4}), ("mqa", {"causal": True, "kv_heads": 1}), ("mha", {"causal": True, "kv_heads": _HEADS}), -] - -_attention_rotary_cases = [ - # Rotary: packing equivalence is skipped for multi-document inputs (packed rotary uses global - # positions; per-sequence reference uses per-doc positions). All three checks run for single-doc inputs. - ("causal_rotary", {"causal": True, "rotary": True}), -] - -_attention_norm_variants = [ - ("no_norm", {}), - ("query_norm", {"query_norm": True}), - ("key_norm", {"key_norm": True}), - ("value_norm", {"value_norm": True}), - ("both_norms", {"query_norm": True, "key_norm": True}), - ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), -] - -_attention_shared_key_value_cases = [ - ("shared_key_value", {"shared_key_value": True}), - ("shared_key_value_rotary", {"shared_key_value": True, "rotary": True}), - # Gemma 4's full-attention layer combines shared_key_value with ProportionalRotary. +): + for lengths in _LENGTHS_FULL: + _attention_test_cases.append((AttentionTestConfig(name=f"{name}_no_norm", **kwargs), lengths)) + +# Per-head norm variants on causal and shared key/value bases. Rotary bases use single-doc +# inputs because the packed and per-sequence rotary references diverge across boundaries. +for base_name, base_kwargs, variants, length_set in ( + ( + "causal", + {"causal": True}, + ( + ("query_norm", {"query_norm": True}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("both_norms", {"query_norm": True, "key_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SHORT, + ), + ( + "causal_rotary", + {"causal": True, "rotary": True}, + ( + ("no_norm", {}), + ("query_norm", {"query_norm": True}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("both_norms", {"query_norm": True, "key_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SINGLE, + ), + ( + "shared_key_value", + {"shared_key_value": True}, + ( + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SHORT, + ), + ( + "shared_key_value_rotary", + {"shared_key_value": True, "rotary": True}, + ( + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SINGLE, + ), ( "shared_key_value_proportional_rotary", {"shared_key_value": True, "rotary": True, "rotary_partial_rotary_factor": 0.5}, + ( + ("no_norm", {}), + ("key_norm", {"key_norm": True}), + ("value_norm", {"value_norm": True}), + ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), + ), + _LENGTHS_SINGLE, ), -] - -_attention_shared_key_value_norm_variants = [ - ("no_norm", {}), - ("key_norm", {"key_norm": True}), - ("value_norm", {"value_norm": True}), - ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}), -] - -# Norms apply per-head and don't interact with mask/group structure, so we test all norm -# variants on a single base (causal) instead of crossing every norm with every base. -# Lengths matter for packing/flash equivalence checks, so we sweep all lengths on the -# base × no_norm cases (where seq layout interacts most with attention math). -_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]] -_LENGTHS_SHORT = [[15], [4, 1, 10]] -_LENGTHS_SINGLE = [[15]] - - -def _build_test_cases() -> list[tuple[AttentionTestConfig, list[int]]]: - cases: list[tuple[AttentionTestConfig, list[int]]] = [] - for base_name, base_kwargs in _base_attention_cases: - config = AttentionTestConfig(name=f"{base_name}_no_norm", **base_kwargs) - cases.extend((config, lengths) for lengths in _LENGTHS_FULL) - for variant_name, variant_kwargs in _attention_norm_variants: - if variant_name == "no_norm": - continue - config = AttentionTestConfig(name=f"causal_{variant_name}", causal=True, **variant_kwargs) - cases.extend((config, lengths) for lengths in _LENGTHS_SHORT) - for base_name, base_kwargs in _attention_rotary_cases: - for variant_name, variant_kwargs in _attention_norm_variants: - config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) - cases.extend((config, lengths) for lengths in _LENGTHS_SINGLE) - for base_name, base_kwargs in _attention_shared_key_value_cases: - lengths_set = _LENGTHS_SINGLE if base_kwargs.get("rotary") else _LENGTHS_SHORT - for variant_name, variant_kwargs in _attention_shared_key_value_norm_variants: - config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs) - cases.extend((config, lengths) for lengths in lengths_set) - return cases - +): + for variant_name, variant_kwargs in variants: + for lengths in length_set: + _attention_test_cases.append( + ( + AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs), + lengths, + ) + ) -_attention_test_cases = _build_test_cases() +# head_size > 256 — exercises the SDPA-only regime (flash caps at 256). +for name, kwargs in ( + ("large_head_causal", {"causal": True, "head_size": 320}), + ("large_head_mqa", {"causal": True, "head_size": 320, "kv_heads": 1}), +): + for lengths in _LENGTHS_SHORT: + _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths)) def _run_per_seq_reference( @@ -357,50 +380,56 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None: Assert.rms_close_relative(param.grad_buffer, grad_ref, 1e-5, 1e-7, msg=name) stage.reset_gradients() - # Flash equivalence check: packed flash output must match per-sequence bfloat16 backup reference. - if _flash_available: - distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True) - distributed_bf16 = Distributed(distributed_config_bf16) + # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output must + # match a per-sequence bfloat16 backup reference. + if not torch.cuda.is_available(): + return - attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer( - distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False - ) - stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16) - for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True): - param_bf16.data.copy_(param_f32.data) - - hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16) - out_ref_bf16 = _run_per_seq_reference( - attention_backup_bf16, - stage_backup_bf16, - distributed_config_bf16, - hidden_states_bf16, - lengths, - device, - with_backward=False, - ) + distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True) + distributed_bf16 = Distributed(distributed_config_bf16) + + attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer( + distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False + ) + stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16) + for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True): + param_bf16.data.copy_(param_f32.data) + + hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16) + out_ref_bf16 = _run_per_seq_reference( + attention_backup_bf16, + stage_backup_bf16, + distributed_config_bf16, + hidden_states_bf16, + lengths, + device, + with_backward=False, + ) - attention_flash: Attention = config.get_attention_config("flash").get_layer( + def _check_packed(implementation: str) -> None: + attention_impl: Attention = config.get_attention_config(implementation).get_layer( distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False ) - stage_flash = get_stage([attention_flash], distributed_bf16) - for param_flash, param_f32 in zip(attention_flash.parameters(), attention.parameters(), strict=True): - param_flash.data.copy_(param_f32.data) - - (model_input_flash,) = LanguageModelBatch( + stage_impl = get_stage([attention_impl], distributed_bf16) + for param_impl, param_f32 in zip(attention_impl.parameters(), attention.parameters(), strict=True): + param_impl.data.copy_(param_f32.data) + (model_input,) = LanguageModelBatch( tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths ).get_model_inputs( LanguageModelBatchPreprocessingConfig( distributed=distributed_config_bf16, predicted_tokens=0, - **attention_flash.get_preprocessing_config(), + **attention_impl.get_preprocessing_config(), ) ) - kwargs_flash = model_input_flash.to_kwargs() - attention_flash.preprocess(kwargs_flash) - out_flash, _ = stage_flash.forward(hidden_states_bf16, kwargs_flash) - - Assert.rms_close_relative(out_flash, out_ref_bf16, 5e-3, 1e-7) + kwargs_impl = model_input.to_kwargs() + attention_impl.preprocess(kwargs_impl) + out_impl, _ = stage_impl.forward(hidden_states_bf16, kwargs_impl) + Assert.rms_close_relative(out_impl, out_ref_bf16, 5e-3, 1e-7) + + if _flash_available and config.head_size <= 256: + _check_packed("flash") + _check_packed("sdpa") @pytest.mark.slow