From c24647ac163107e6ccf276cbe300d70969d75c22 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Thu, 7 May 2026 19:12:39 -0400
Subject: [PATCH 1/6] Add SDPA attention for head_size > 256
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Flash-attn errors out at head_size > 256, so head_size=512 models cannot
train without materializing the full O(S²) attention matrix via the backup
path. Add `AttentionImplementation.sdpa` using `torch.nested` to bridge the
packed-varlen layout to SDPA's batched signature, pinning the EFFICIENT
backend. K/V are manually repeat_interleaved to match Q heads because the
fused kernels reject broadcasted GQA inputs.

Auto-fallback: flash when bf16/fp16 + head_size <= 256 + flash is available;
backup for windowed attention (the sdpa path does not support sliding
window); sdpa otherwise.

Tests: SDPA equivalence check parallel to flash, gated on CUDA + bf16; two
head_size=320 cases exercising the SDPA-only regime; refactored
parametrization from `_build_test_cases` plus single-use variant lists into
a few inline for-loops at module level.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/layers/attention/attention.py |  65 +++++++-
 fast_llm/layers/attention/config.py    |   1 +
 tests/layers/test_attention.py         | 220 ++++++++++++++-----------
 3 files changed, 183 insertions(+), 103 deletions(-)

diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index 12f85bf28..b9a0cc944 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -81,10 +81,19 @@ def __init__(
         )
         self._implementation = self._config.implementation
         if self._implementation == AttentionImplementation.auto:
-            if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16):
+            if (
+                _flash_available
+                and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16)
+                and self._config.head_size <= 256
+            ):
                 self._implementation = AttentionImplementation.flash
-            else:
+            elif self._config.window_size is not None:
+                # SDPA path doesn't support sliding window; backup is the only fallback that does.
                 self._implementation = AttentionImplementation.backup
+            else:
+                self._implementation = AttentionImplementation.sdpa
+        if self._implementation == AttentionImplementation.sdpa:
+            assert self._config.window_size is None, "SDPA implementation does not support sliding window."
 
         self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
         self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim(
@@ -258,6 +267,38 @@ def _attn_flash(
             softmax_scale=self._softmax_scale,
         )
 
+    def _attn_sdpa(
+        self,
+        query: torch.Tensor,  # total_q, heads, head_size
+        key: torch.Tensor,  # total_k, head_groups, head_size
+        value: torch.Tensor,  # total_k, head_groups, head_size
+        kwargs: dict[str, typing.Any],
+    ) -> torch.Tensor:  # total_q, heads, head_size
+        # SDPA's EFFICIENT backend (the only one that supports head_size > 256) requires
+        # Q/K/V to have the same num_heads, so we materialize K/V across query heads.
+        # Wrap as nested-jagged to give SDPA the per-document mask via batch elements,
+        # avoiding the pack→pad→gather dance.
+        if self._local_heads_per_group > 1:
+            key = key.repeat_interleave(self._local_heads_per_group, dim=1)
+            value = value.repeat_interleave(self._local_heads_per_group, dim=1)
+        cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
+        cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
+        query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)
+        key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k)
+        value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k)
+
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+            output_nested = torch.nn.functional.scaled_dot_product_attention(
+                query_nested.transpose(1, 2),
+                key_nested.transpose(1, 2),
+                value_nested.transpose(1, 2),
+                is_causal=self._config.causal,
+                dropout_p=self._config.dropout if self.training else 0.0,
+                scale=self._softmax_scale,
+            ).transpose(1, 2)
+
+        return output_nested.values()
+
     def _apply_norm_with_grad_capture(
         self, norm: torch.nn.Module, x: torch.Tensor
     ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
@@ -420,6 +461,8 @@ def _forward(
         with set_generator(self._distributed.tp_generator):
             if self._implementation == AttentionImplementation.flash:
                 input_ = self._attn_flash(query, key, value, kwargs)
+            elif self._implementation == AttentionImplementation.sdpa:
+                input_ = self._attn_sdpa(query, key, value, kwargs)
             elif self._implementation == AttentionImplementation.backup:
                 # TODO: Avoid the flattens.
                 input_ = self._attn_backup(query, key, value, kwargs)
@@ -472,7 +515,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
 
         attention_compute = sequence_q * sequence_k * attn_compute_base
 
-        if (not config.hardware) or self._implementation in AttentionImplementation.flash:
+        if (not config.hardware) or self._implementation in (
+            AttentionImplementation.flash,
+            AttentionImplementation.sdpa,
+        ):
             # Remove non-causal part. (TODO: Support non-causal)
             # TODO: Compute is overestimated without cross-document attention.
             attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2
@@ -498,15 +544,18 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c
         )
 
     def get_preprocessing_config(self) -> dict[str, typing.Any]:
-        return (
-            {
+        if self._implementation == AttentionImplementation.flash:
+            return {
                 "return_cumulative_sequence_lengths": True,
                 "return_max_sequence_lengths": True,
                 "causal": self._config.causal,
             }
-            if self._implementation == AttentionImplementation.flash
-            else {"return_document_index": True, "causal": self._config.causal}
-        )
+        elif self._implementation == AttentionImplementation.sdpa:
+            return {"return_cumulative_sequence_lengths": True, "causal": self._config.causal}
+        elif self._implementation == AttentionImplementation.backup:
+            return {"return_document_index": True, "causal": self._config.causal}
+        else:
+            raise NotImplementedError(self._implementation)
 
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self._rotary.preprocess(kwargs)
diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py
index cc5d80e88..69aa4f484 100644
--- a/fast_llm/layers/attention/config.py
+++ b/fast_llm/layers/attention/config.py
@@ -38,6 +38,7 @@ class AttentionKwargs(MixerKwargs):
 class AttentionImplementation(enum.StrEnum):
     auto = "auto"
     flash = "flash"
+    sdpa = "sdpa"
     backup = "backup"
 
 
diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py
index d572816b2..39a4d5d58 100644
--- a/tests/layers/test_attention.py
+++ b/tests/layers/test_attention.py
@@ -174,78 +174,101 @@ def expected_output(
     return torch.nn.functional.linear(attn_out.flatten(1), attention.dense.weight.detach())
 
 
-_base_attention_cases = [
+_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]]
+_LENGTHS_SHORT = [[15], [4, 1, 10]]
+_LENGTHS_SINGLE = [[15]]
+
+_attention_test_cases: list[tuple[AttentionTestConfig, list[int]]] = []
+
+# Mask, group, and window base cases — no norms, swept over all length sets.
+for name, kwargs in (
     ("causal", {"causal": True}),
     ("noncausal", {"causal": False}),
     ("window", {"causal": True, "window_size": 4}),
     ("mqa", {"causal": True, "kv_heads": 1}),
     ("mha", {"causal": True, "kv_heads": _HEADS}),
-]
-
-_attention_rotary_cases = [
-    # Rotary: packing equivalence is skipped for multi-document inputs (packed rotary uses global
-    # positions; per-sequence reference uses per-doc positions). All three checks run for single-doc inputs.
-    ("causal_rotary", {"causal": True, "rotary": True}),
-]
-
-_attention_norm_variants = [
-    ("no_norm", {}),
-    ("query_norm", {"query_norm": True}),
-    ("key_norm", {"key_norm": True}),
-    ("value_norm", {"value_norm": True}),
-    ("both_norms", {"query_norm": True, "key_norm": True}),
-    ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
-]
-
-_attention_shared_key_value_cases = [
-    ("shared_key_value", {"shared_key_value": True}),
-    ("shared_key_value_rotary", {"shared_key_value": True, "rotary": True}),
-    # Gemma 4's full-attention layer combines shared_key_value with ProportionalRotary.
+):
+    for lengths in _LENGTHS_FULL:
+        _attention_test_cases.append((AttentionTestConfig(name=f"{name}_no_norm", **kwargs), lengths))
+
+# Per-head norm variants on causal and shared key/value bases. Rotary bases use single-doc
+# inputs because the packed and per-sequence rotary references diverge across boundaries.
+for base_name, base_kwargs, variants, length_set in (
+    (
+        "causal",
+        {"causal": True},
+        (
+            ("query_norm", {"query_norm": True}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("both_norms", {"query_norm": True, "key_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SHORT,
+    ),
+    (
+        "causal_rotary",
+        {"causal": True, "rotary": True},
+        (
+            ("no_norm", {}),
+            ("query_norm", {"query_norm": True}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("both_norms", {"query_norm": True, "key_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SINGLE,
+    ),
+    (
+        "shared_key_value",
+        {"shared_key_value": True},
+        (
+            ("no_norm", {}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SHORT,
+    ),
+    (
+        "shared_key_value_rotary",
+        {"shared_key_value": True, "rotary": True},
+        (
+            ("no_norm", {}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SINGLE,
+    ),
     (
         "shared_key_value_proportional_rotary",
         {"shared_key_value": True, "rotary": True, "rotary_partial_rotary_factor": 0.5},
+        (
+            ("no_norm", {}),
+            ("key_norm", {"key_norm": True}),
+            ("value_norm", {"value_norm": True}),
+            ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
+        ),
+        _LENGTHS_SINGLE,
     ),
-]
-
-_attention_shared_key_value_norm_variants = [
-    ("no_norm", {}),
-    ("key_norm", {"key_norm": True}),
-    ("value_norm", {"value_norm": True}),
-    ("all_norms", {"query_norm": True, "key_norm": True, "value_norm": True}),
-]
-
-# Norms apply per-head and don't interact with mask/group structure, so we test all norm
-# variants on a single base (causal) instead of crossing every norm with every base.
-# Lengths matter for packing/flash equivalence checks, so we sweep all lengths on the
-# base × no_norm cases (where seq layout interacts most with attention math).
-_LENGTHS_FULL = [[15], [6, 9], [4, 1, 10], [20, 32, 10, 11, 9, 18]]
-_LENGTHS_SHORT = [[15], [4, 1, 10]]
-_LENGTHS_SINGLE = [[15]]
-
-
-def _build_test_cases() -> list[tuple[AttentionTestConfig, list[int]]]:
-    cases: list[tuple[AttentionTestConfig, list[int]]] = []
-    for base_name, base_kwargs in _base_attention_cases:
-        config = AttentionTestConfig(name=f"{base_name}_no_norm", **base_kwargs)
-        cases.extend((config, lengths) for lengths in _LENGTHS_FULL)
-    for variant_name, variant_kwargs in _attention_norm_variants:
-        if variant_name == "no_norm":
-            continue
-        config = AttentionTestConfig(name=f"causal_{variant_name}", causal=True, **variant_kwargs)
-        cases.extend((config, lengths) for lengths in _LENGTHS_SHORT)
-    for base_name, base_kwargs in _attention_rotary_cases:
-        for variant_name, variant_kwargs in _attention_norm_variants:
-            config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs)
-            cases.extend((config, lengths) for lengths in _LENGTHS_SINGLE)
-    for base_name, base_kwargs in _attention_shared_key_value_cases:
-        lengths_set = _LENGTHS_SINGLE if base_kwargs.get("rotary") else _LENGTHS_SHORT
-        for variant_name, variant_kwargs in _attention_shared_key_value_norm_variants:
-            config = AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs)
-            cases.extend((config, lengths) for lengths in lengths_set)
-    return cases
-
+):
+    for variant_name, variant_kwargs in variants:
+        for lengths in length_set:
+            _attention_test_cases.append(
+                (
+                    AttentionTestConfig(name=f"{base_name}_{variant_name}", **base_kwargs, **variant_kwargs),
+                    lengths,
+                )
+            )
 
-_attention_test_cases = _build_test_cases()
+# head_size > 256 — exercises the SDPA-only regime (flash caps at 256).
+for name, kwargs in (
+    ("large_head_causal", {"causal": True, "head_size": 320}),
+    ("large_head_mqa", {"causal": True, "head_size": 320, "kv_heads": 1}),
+):
+    for lengths in _LENGTHS_SHORT:
+        _attention_test_cases.append((AttentionTestConfig(name=name, **kwargs), lengths))
 
 
 def _run_per_seq_reference(
@@ -357,50 +380,57 @@ def _test_attention(config: AttentionTestConfig, lengths: list[int]) -> None:
             Assert.rms_close_relative(param.grad_buffer, grad_ref, 1e-5, 1e-7, msg=name)
     stage.reset_gradients()
 
-    # Flash equivalence check: packed flash output must match per-sequence bfloat16 backup reference.
-    if _flash_available:
-        distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True)
-        distributed_bf16 = Distributed(distributed_config_bf16)
+    # Flash and SDPA equivalence checks: each implementation's packed bfloat16 output must
+    # match a per-sequence bfloat16 backup reference.
+    if not torch.cuda.is_available():
+        return
 
-        attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer(
-            distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False
-        )
-        stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16)
-        for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True):
-            param_bf16.data.copy_(param_f32.data)
-
-        hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16)
-        out_ref_bf16 = _run_per_seq_reference(
-            attention_backup_bf16,
-            stage_backup_bf16,
-            distributed_config_bf16,
-            hidden_states_bf16,
-            lengths,
-            device,
-            with_backward=False,
-        )
+    distributed_config_bf16 = DistributedConfig(compute_dtype=DataType.bfloat16, use_cuda=True)
+    distributed_bf16 = Distributed(distributed_config_bf16)
+
+    attention_backup_bf16: Attention = config.get_attention_config("backup").get_layer(
+        distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False
+    )
+    stage_backup_bf16 = get_stage([attention_backup_bf16], distributed_bf16)
+    for param_bf16, param_f32 in zip(attention_backup_bf16.parameters(), attention.parameters(), strict=True):
+        param_bf16.data.copy_(param_f32.data)
+
+    hidden_states_bf16 = hidden_states.detach().to(torch.bfloat16)
+    out_ref_bf16 = _run_per_seq_reference(
+        attention_backup_bf16,
+        stage_backup_bf16,
+        distributed_config_bf16,
+        hidden_states_bf16,
+        lengths,
+        device,
+        with_backward=False,
+    )
 
-        attention_flash: Attention = config.get_attention_config("flash").get_layer(
+    def _check_packed(implementation: str) -> None:
+        attention_impl: Attention = config.get_attention_config(implementation).get_layer(
             distributed_config_bf16, hidden_dim, lr_scale=None, peft=None, return_bias=False
         )
-        stage_flash = get_stage([attention_flash], distributed_bf16)
-        for param_flash, param_f32 in zip(attention_flash.parameters(), attention.parameters(), strict=True):
-            param_flash.data.copy_(param_f32.data)
-
-        (model_input_flash,) = LanguageModelBatch(
+        stage_impl = get_stage([attention_impl], distributed_bf16)
+        for param_impl, param_f32 in zip(attention_impl.parameters(), attention.parameters(), strict=True):
+            param_impl.data.copy_(param_f32.data)
+        (model_input,) = LanguageModelBatch(
             tokens=torch.empty(num_tokens, dtype=torch.int64, device=device), lengths=lengths
         ).get_model_inputs(
             LanguageModelBatchPreprocessingConfig(
                 distributed=distributed_config_bf16,
                 predicted_tokens=0,
-                **attention_flash.get_preprocessing_config(),
+                **attention_impl.get_preprocessing_config(),
             )
         )
-        kwargs_flash = model_input_flash.to_kwargs()
-        attention_flash.preprocess(kwargs_flash)
-        out_flash, _ = stage_flash.forward(hidden_states_bf16, kwargs_flash)
-
-        Assert.rms_close_relative(out_flash, out_ref_bf16, 5e-3, 1e-7)
+        kwargs_impl = model_input.to_kwargs()
+        attention_impl.preprocess(kwargs_impl)
+        out_impl, _ = stage_impl.forward(hidden_states_bf16, kwargs_impl)
+        Assert.rms_close_relative(out_impl, out_ref_bf16, 5e-3, 1e-7)
+
+    if _flash_available and config.head_size <= 256:
+        _check_packed("flash")
+    if config.window_size is None:
+        _check_packed("sdpa")
 
 
 @pytest.mark.slow

From 23412b70f41bdb15bb605a033052b1d38ecf19dd Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 8 May 2026 14:05:43 -0400
Subject: [PATCH 2/6] Route auto-fallback to backup on CPU

The SDPA path uses `nested_tensor_from_jagged + is_causal=True` which has no
viable backend on CPU (math rejects nested + is_causal; the fused
EFFICIENT/Flash backends are CUDA-only).
Auto previously routed CPU runs through SDPA and they would crash; route them
to backup.

Also widens the SDPA branch to fp32 explicitly: the EFFICIENT backend engages
on CUDA across bf16/fp16/fp32, and benchmarking confirms it beats backup on
memory at every length and matches it on time at seq_len >= 4096 (backup
grows quadratically; SDPA stays near constant).

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/layers/attention/attention.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index b9a0cc944..e3e461cf5 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -87,11 +87,13 @@ def __init__(
                 and self._config.head_size <= 256
             ):
                 self._implementation = AttentionImplementation.flash
-            elif self._config.window_size is not None:
-                # SDPA path doesn't support sliding window; backup is the only fallback that does.
-                self._implementation = AttentionImplementation.backup
-            else:
+            elif self._distributed_config.use_cuda and self._config.window_size is None:
+                # SDPA's EFFICIENT backend handles every dtype on CUDA; on CPU the
+                # nested + is_causal path has no viable backend, and SDPA does not
+                # support sliding window so windowed runs need backup either way.
                 self._implementation = AttentionImplementation.sdpa
+            else:
+                self._implementation = AttentionImplementation.backup
         if self._implementation == AttentionImplementation.sdpa:
             assert self._config.window_size is None, "SDPA implementation does not support sliding window."
 

From bd17da3787be9e6db593623fe8695f913bc44451 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 8 May 2026 14:41:07 -0400
Subject: [PATCH 3/6] Use backup mask in SDPA fallback paths
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The previous attempt routed CPU and windowed configurations to backup because
the nested + is_causal=True form has no viable backend on CPU and cannot
express sliding window. SDPA actually works fine in those cases when given an
explicit attn_mask: backup's preprocessing already builds the combined
causal+document mask (and threads sliding window into it), so the SDPA path
can reuse it as-is.

CUDA without a window keeps the nested + is_causal path so EFFICIENT runs
without materializing the mask. CUDA with a window and CPU runs both fall
through to dense + attn_mask, which lets MATH engage on CPU and reuses the
windowed mask on CUDA. Auto-fallback simplifies to flash-or-sdpa: SDPA now
covers every case backup used to (CPU, windowed without flash,
head_size > 256).

Verified on H100 bf16 head_size=512 that the dense + attn_mask form also
engages EFFICIENT (peak 323 MiB vs 319 MiB for is_causal — the 4 MiB delta is
the mask itself).
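
For illustration, a minimal standalone sketch of the dense + attn_mask form
(the hand-rolled mask builder, shapes, and sizes below are assumptions for
the example, not the module's actual preprocessing):

    import torch

    def packed_mask(lengths: list[int], window_size: int | None = None) -> torch.Tensor:
        # Boolean mask, True = may attend. Shape (1, 1, total, total) so it
        # broadcasts over batch and heads in scaled_dot_product_attention.
        total = sum(lengths)
        position = torch.arange(total)
        document = torch.repeat_interleave(torch.arange(len(lengths)), torch.tensor(lengths))
        mask = (position[:, None] >= position[None, :]) & (document[:, None] == document[None, :])
        if window_size is not None:
            # Keep only keys within the trailing window of each query position.
            mask &= position[:, None] - position[None, :] < window_size
        return mask[None, None]

    q = k = v = torch.randn(1, 4, 25, 64)  # batch, heads, packed tokens, head_size
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=packed_mask([15, 4, 6], window_size=4)
    )
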
Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/layers/attention/attention.py | 78 ++++++++++++++++----------
 tests/layers/test_attention.py         |  3 +-
 2 files changed, 49 insertions(+), 32 deletions(-)

diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index e3e461cf5..605074222 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -87,15 +87,8 @@ def __init__(
                 and self._config.head_size <= 256
             ):
                 self._implementation = AttentionImplementation.flash
-            elif self._distributed_config.use_cuda and self._config.window_size is None:
-                # SDPA's EFFICIENT backend handles every dtype on CUDA; on CPU the
-                # nested + is_causal path has no viable backend, and SDPA does not
-                # support sliding window so windowed runs need backup either way.
-                self._implementation = AttentionImplementation.sdpa
             else:
-                self._implementation = AttentionImplementation.backup
-        if self._implementation == AttentionImplementation.sdpa:
-            assert self._config.window_size is None, "SDPA implementation does not support sliding window."
+                self._implementation = AttentionImplementation.sdpa
 
         self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor)
         self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim(
@@ -276,30 +269,48 @@ def _attn_sdpa(
         value: torch.Tensor,  # total_k, head_groups, head_size
         kwargs: dict[str, typing.Any],
     ) -> torch.Tensor:  # total_q, heads, head_size
-        # SDPA's EFFICIENT backend (the only one that supports head_size > 256) requires
-        # Q/K/V to have the same num_heads, so we materialize K/V across query heads.
-        # Wrap as nested-jagged to give SDPA the per-document mask via batch elements,
-        # avoiding the pack→pad→gather dance.
+        # SDPA's fused kernels require Q/K/V to share heads, so we expand K/V across query heads.
         if self._local_heads_per_group > 1:
             key = key.repeat_interleave(self._local_heads_per_group, dim=1)
             value = value.repeat_interleave(self._local_heads_per_group, dim=1)
-        cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
-        cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
-        query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)
-        key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k)
-        value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k)
-
-        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
-            output_nested = torch.nn.functional.scaled_dot_product_attention(
-                query_nested.transpose(1, 2),
-                key_nested.transpose(1, 2),
-                value_nested.transpose(1, 2),
-                is_causal=self._config.causal,
+
+        if query.is_cuda and self._config.window_size is None:
+            # Most-efficient path: nested-jagged + is_causal lets EFFICIENT skip materializing
+            # the attention mask. Document boundaries are encoded by the per-doc batch elements.
+            cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
+            cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
+            query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)
+            key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k)
+            value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k)
+            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+                output_nested = torch.nn.functional.scaled_dot_product_attention(
+                    query_nested.transpose(1, 2),
+                    key_nested.transpose(1, 2),
+                    value_nested.transpose(1, 2),
+                    is_causal=self._config.causal,
+                    dropout_p=self._config.dropout if self.training else 0.0,
+                    scale=self._softmax_scale,
+                ).transpose(1, 2)
+            return output_nested.values()
+
+        # CPU MATH rejects nested + is_causal, and the nested path can't express sliding window.
+        # Both fall back on the same dense + attn_mask form, reusing backup's preprocessed mask.
+        # Backup builds it as (1, sq, 1, sk) for its head-grouped layout; SDPA wants (B, H, sq, sk).
+        attention_mask = kwargs[AttentionKwargs.attention_mask]
+        if attention_mask is not None:
+            attention_mask = attention_mask.transpose(1, 2)
+        return (
+            torch.nn.functional.scaled_dot_product_attention(
+                query.unsqueeze(0).transpose(1, 2),
+                key.unsqueeze(0).transpose(1, 2),
+                value.unsqueeze(0).transpose(1, 2),
+                attn_mask=attention_mask,
                 dropout_p=self._config.dropout if self.training else 0.0,
                 scale=self._softmax_scale,
-            ).transpose(1, 2)
-
-        return output_nested.values()
+            )
+            .transpose(1, 2)
+            .squeeze(0)
+        )
 
     def _apply_norm_with_grad_capture(
         self, norm: torch.nn.Module, x: torch.Tensor
@@ -552,16 +563,23 @@ def get_preprocessing_config(self) -> dict[str, typing.Any]:
                 "return_max_sequence_lengths": True,
                 "causal": self._config.causal,
             }
-        elif self._implementation == AttentionImplementation.sdpa:
+        elif (
+            self._implementation == AttentionImplementation.sdpa
+            and self._distributed_config.use_cuda
+            and self._config.window_size is None
+        ):
            return {"return_cumulative_sequence_lengths": True, "causal": self._config.causal}
-        elif self._implementation == AttentionImplementation.backup:
+        elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup):
            return {"return_document_index": True, "causal": self._config.causal}
         else:
             raise NotImplementedError(self._implementation)
 
     def preprocess(self, kwargs: dict[str, typing.Any]) -> None:
         self._rotary.preprocess(kwargs)
-        if self._implementation == AttentionImplementation.backup:
+        if self._implementation == AttentionImplementation.backup or (
+            self._implementation == AttentionImplementation.sdpa
+            and (not self._distributed_config.use_cuda or self._config.window_size is not None)
+        ):
             self._preprocess_for_backup_attention(kwargs)
 
     def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> None:
diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py
index 39a4d5d58..ef96dbf96 100644
--- a/tests/layers/test_attention.py
+++ b/tests/layers/test_attention.py
@@ -429,8 +429,7 @@ def _check_packed(implementation: str) -> None:
 
     if _flash_available and config.head_size <= 256:
         _check_packed("flash")
-    if config.window_size is None:
-        _check_packed("sdpa")
+    _check_packed("sdpa")
 
 
 @pytest.mark.slow

From 4c99e884e5cc1edf76c0447788220d5a8ad8eee1 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 8 May 2026 15:54:25 -0400
Subject: [PATCH 4/6] Unify the two SDPA paths around a single F.scaled_dot_product_attention call
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The CUDA-no-window and dense-mask paths shared the K/V expansion, the SDPA
call signature (dropout + scale), and the (B, H, S, D) layout requirement.
Lift those out: rebind query/key/value to either nested-jagged or
unsqueeze(0)'d 4D tensors in the per-path setup, build an `sdpa_args` dict
that adds `is_causal=...` for nested or `attn_mask=...` for dense, then make
a single SDPA call that works for both. The unwrap branches on
`output.is_nested`.

Also drops the explicit EFFICIENT_ATTENTION pin from the nested path — nested
+ is_causal=True has no other viable backend (MATH and Flash both reject it),
so the auto pick lands on EFFICIENT or the call errors out either way.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/layers/attention/attention.py | 63 ++++++++++++--------------
 1 file changed, 29 insertions(+), 34 deletions(-)

diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index 605074222..2024cec45 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -274,43 +274,38 @@ def _attn_sdpa(
             key = key.repeat_interleave(self._local_heads_per_group, dim=1)
             value = value.repeat_interleave(self._local_heads_per_group, dim=1)
 
+        sdpa_args: dict[str, typing.Any] = {
+            "dropout_p": self._config.dropout if self.training else 0.0,
+            "scale": self._softmax_scale,
+        }
         if query.is_cuda and self._config.window_size is None:
-            # Most-efficient path: nested-jagged + is_causal lets EFFICIENT skip materializing
-            # the attention mask. Document boundaries are encoded by the per-doc batch elements.
+            # Wrap each document as its own batch element via nested-jagged so cross-doc masking
+            # is structural and EFFICIENT skips materializing the attention mask.
             cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
             cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
-            query_nested = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)
-            key_nested = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k)
-            value_nested = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k)
-            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
-                output_nested = torch.nn.functional.scaled_dot_product_attention(
-                    query_nested.transpose(1, 2),
-                    key_nested.transpose(1, 2),
-                    value_nested.transpose(1, 2),
-                    is_causal=self._config.causal,
-                    dropout_p=self._config.dropout if self.training else 0.0,
-                    scale=self._softmax_scale,
-                ).transpose(1, 2)
-            return output_nested.values()
-
-        # CPU MATH rejects nested + is_causal, and the nested path can't express sliding window.
-        # Both fall back on the same dense + attn_mask form, reusing backup's preprocessed mask.
-        # Backup builds it as (1, sq, 1, sk) for its head-grouped layout; SDPA wants (B, H, sq, sk).
-        attention_mask = kwargs[AttentionKwargs.attention_mask]
-        if attention_mask is not None:
-            attention_mask = attention_mask.transpose(1, 2)
-        return (
-            torch.nn.functional.scaled_dot_product_attention(
-                query.unsqueeze(0).transpose(1, 2),
-                key.unsqueeze(0).transpose(1, 2),
-                value.unsqueeze(0).transpose(1, 2),
-                attn_mask=attention_mask,
-                dropout_p=self._config.dropout if self.training else 0.0,
-                scale=self._softmax_scale,
-            )
-            .transpose(1, 2)
-            .squeeze(0)
-        )
+            query = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)
+            key = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k)
+            value = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k)
+            sdpa_args["is_causal"] = self._config.causal
+        else:
+            # Dense + backup's preprocessed causal+document mask. Required on CPU (MATH rejects
+            # nested + is_causal) and on CUDA with sliding window (the nested path can't express
+            # it). Backup builds the mask as (1, sq, 1, sk); SDPA wants (B, H, sq, sk).
+            attention_mask = kwargs[AttentionKwargs.attention_mask]
+            if attention_mask is not None:
+                attention_mask = attention_mask.transpose(1, 2)
+            query = query.unsqueeze(0)
+            key = key.unsqueeze(0)
+            value = value.unsqueeze(0)
+            sdpa_args["attn_mask"] = attention_mask
+
+        output = torch.nn.functional.scaled_dot_product_attention(
+            query.transpose(1, 2),
+            key.transpose(1, 2),
+            value.transpose(1, 2),
+            **sdpa_args,
+        ).transpose(1, 2)
+        return output.values() if output.is_nested else output.squeeze(0)
 
     def _apply_norm_with_grad_capture(
         self, norm: torch.nn.Module, x: torch.Tensor

From ffa8b7ec34b9ea9d817fe4cb99e9894b6369497d Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 8 May 2026 16:31:09 -0400
Subject: [PATCH 5/6] Note SDPA nested dispatch sync cost in the attention comment

The nested path floors per-call wall around 6 ms because SDPA's nested
dispatch pulls `max_seqlen` / `min_seqlen` to host (5 cudaMemcpyAsync DtoH +
cudaStreamSynchronize per call). Sync count is fixed regardless of num_docs,
so the path stays much faster than dense+mask in varlen training; the comment
just makes the cost discoverable.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/layers/attention/attention.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index 2024cec45..57b826ab2 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -280,7 +280,9 @@ def _attn_sdpa(
         }
         if query.is_cuda and self._config.window_size is None:
             # Wrap each document as its own batch element via nested-jagged so cross-doc masking
-            # is structural and EFFICIENT skips materializing the attention mask.
+            # is structural and EFFICIENT skips materializing the attention mask. SDPA's nested
+            # dispatch reads `max_seqlen`/`min_seqlen` to host (5 cudaMemcpyAsync DtoH per call),
+            # which floors per-call wall at ~6 ms; still much faster than dense+mask in varlen.
             cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
             cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
             query = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)

From f6958b5aa77847e89f9d71fa1a339148 Mon Sep 17 00:00:00 2001
From: Joel Lamy-Poirier
Date: Fri, 8 May 2026 16:57:23 -0400
Subject: [PATCH 6/6] Pre-compute min/max seq lengths so SDPA's nested path doesn't sync
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

PyTorch's nested SDPA dispatch reads `max_seqlen` and `min_seqlen` to host on
every call (5 cudaMemcpyAsync DtoH + cudaStreamSynchronize per call) when
they aren't supplied. Both are trivially derivable from the Python `lengths`
list at preprocessing time, so we compute them as plain ints, thread them
through `BlockModelInput` / kwargs, and pass them to
`nested_tensor_from_jagged`.

While doing this, drop the `torch.full((1,), ..., device=...)` wrap on
`max_lengths` — the value was always a Python int, and flash accepts an int
directly (verified). The auto-device-move on the `Document` base class only
moves Tensor fields, so plain ints pass through to_kwargs untouched.

Sync events per call (Llama-7B-shape, 4 docs × 4096):
  before: 5 cudaStreamSynchronize + 5 Memcpy DtoH
  after:  0

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 fast_llm/data/document/block.py        | 25 +++++++++++++------
 fast_llm/data/document/config.py       |  1 +
 fast_llm/layers/attention/attention.py | 34 ++++++++++++++++++++------
 fast_llm/layers/attention/config.py    |  2 ++
 4 files changed, 48 insertions(+), 14 deletions(-)

diff --git a/fast_llm/data/document/block.py b/fast_llm/data/document/block.py
index 530be42ea..b9f2e30f2 100644
--- a/fast_llm/data/document/block.py
+++ b/fast_llm/data/document/block.py
@@ -24,8 +24,10 @@ class BlockModelInput(ModelInput):
     lengths: list[int] = None
     cumulative_lengths_q: torch.Tensor | None = None
     cumulative_lengths_k: torch.Tensor | None = None
-    max_length_q: torch.Tensor | None = None
-    max_length_k: torch.Tensor | None = None
+    max_length_q: int | None = None
+    max_length_k: int | None = None
+    min_length_q: int | None = None
+    min_length_k: int | None = None
     document_index_q: torch.Tensor | None = None
     document_index_k: torch.Tensor | None = None
     position_index: torch.Tensor | None = None
@@ -44,6 +46,8 @@ def to_kwargs(self) -> dict[str, typing.Any]:
             AttentionKwargs.cu_seqlens_k: self.cumulative_lengths_k,
             AttentionKwargs.max_seqlen_q: self.max_length_q,
             AttentionKwargs.max_seqlen_k: self.max_length_k,
+            AttentionKwargs.min_seqlen_q: self.min_length_q,
+            AttentionKwargs.min_seqlen_k: self.min_length_k,
             AttentionKwargs.document_index_q: self.document_index_q,
             AttentionKwargs.document_index_k: self.document_index_k,
             LanguageModelKwargs.position_ids: self.position_index,
@@ -101,6 +105,8 @@ def preprocess(self, model_input: BlockModelInput, config: LengthPreprocessingCo
             model_input.cumulative_lengths_q, model_input.cumulative_lengths_k = self.cumulative_lengths
         if config.return_max_sequence_lengths or config.return_document_index:
             model_input.max_length_q, model_input.max_length_k = self.max_lengths
+        if config.return_min_sequence_lengths:
+            model_input.min_length_q, model_input.min_length_k = self.min_lengths
         if config.return_document_index:
             model_input.document_index_q, model_input.document_index_k = self.document_index
         if config.return_position_index:
@@ -118,13 +124,18 @@ def cumulative_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
         return cumulative_lengths_q, cumulative_lengths_k
 
     @functools.cached_property
-    def max_lengths(self) -> tuple[torch.Tensor, torch.Tensor]:
+    def max_lengths(self) -> tuple[int, int]:
         max_length_q = max(self.lengths)
         max_length_k = max(max_length_q, self.sequence_k_past + self.lengths[0] - self.first_document_begin)
-        return (
-            torch.full((1,), max_length_q, dtype=torch.int32, device=self.device),
-            torch.full((1,), max_length_k, dtype=torch.int32, device=self.device),
-        )
+        return max_length_q, max_length_k
+
+    @functools.cached_property
+    def min_lengths(self) -> tuple[int, int]:
+        min_length_q = min(self.lengths)
+        # First doc's K-side length includes the past KV prefix; remaining docs match q-side.
+        first_length_k = self.sequence_k_past + self.lengths[0] - self.first_document_begin
+        min_length_k = min(first_length_k, *self.lengths[1:]) if len(self.lengths) > 1 else first_length_k
+        return min_length_q, min_length_k
 
     @functools.cached_property
     def document_index(self) -> tuple[torch.Tensor, torch.Tensor]:
diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py
index 352311b51..a90bcdebc 100644
--- a/fast_llm/data/document/config.py
+++ b/fast_llm/data/document/config.py
@@ -25,6 +25,7 @@ class LengthPreprocessingConfig(BatchPreprocessingConfig):
     distributed: DistributedConfig = Field()
     return_cumulative_sequence_lengths: bool = Field(default=False)
     return_max_sequence_lengths: bool = Field(default=False)
+    return_min_sequence_lengths: bool = Field(default=False)
     return_document_index: bool = Field(default=False)
     return_position_index: bool = Field(default=False)
 
diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py
index 57b826ab2..9ff1f7846 100644
--- a/fast_llm/layers/attention/attention.py
+++ b/fast_llm/layers/attention/attention.py
@@ -280,14 +280,29 @@ def _attn_sdpa(
         }
         if query.is_cuda and self._config.window_size is None:
             # Wrap each document as its own batch element via nested-jagged so cross-doc masking
-            # is structural and EFFICIENT skips materializing the attention mask. SDPA's nested
-            # dispatch reads `max_seqlen`/`min_seqlen` to host (5 cudaMemcpyAsync DtoH per call),
-            # which floors per-call wall at ~6 ms; still much faster than dense+mask in varlen.
+            # is structural and EFFICIENT skips materializing the attention mask. The dispatch
+            # otherwise reads `max_seqlen`/`min_seqlen` to host on every call; passing them in
+            # explicitly keeps the path sync-free.
             cu_seqlens_q = kwargs[AttentionKwargs.cu_seqlens_q].to(torch.int64)
             cu_seqlens_k = kwargs[AttentionKwargs.cu_seqlens_k].to(torch.int64)
-            query = torch.nested.nested_tensor_from_jagged(query, cu_seqlens_q)
-            key = torch.nested.nested_tensor_from_jagged(key, cu_seqlens_k)
-            value = torch.nested.nested_tensor_from_jagged(value, cu_seqlens_k)
+            query = torch.nested.nested_tensor_from_jagged(
+                query,
+                cu_seqlens_q,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_q],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_q],
+            )
+            key = torch.nested.nested_tensor_from_jagged(
+                key,
+                cu_seqlens_k,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
+            )
+            value = torch.nested.nested_tensor_from_jagged(
+                value,
+                cu_seqlens_k,
+                min_seqlen=kwargs[AttentionKwargs.min_seqlen_k],
+                max_seqlen=kwargs[AttentionKwargs.max_seqlen_k],
+            )
             sdpa_args["is_causal"] = self._config.causal
         else:
             # Dense + backup's preprocessed causal+document mask. Required on CPU (MATH rejects
@@ -565,7 +580,12 @@ def get_preprocessing_config(self) -> dict[str, typing.Any]:
             and self._distributed_config.use_cuda
             and self._config.window_size is None
         ):
-            return {"return_cumulative_sequence_lengths": True, "causal": self._config.causal}
+            return {
+                "return_cumulative_sequence_lengths": True,
+                "return_max_sequence_lengths": True,
+                "return_min_sequence_lengths": True,
+                "causal": self._config.causal,
+            }
         elif self._implementation in (AttentionImplementation.sdpa, AttentionImplementation.backup):
             return {"return_document_index": True, "causal": self._config.causal}
         else:
             raise NotImplementedError(self._implementation)
diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py
index 69aa4f484..f69e2129d 100644
--- a/fast_llm/layers/attention/config.py
+++ b/fast_llm/layers/attention/config.py
@@ -21,6 +21,8 @@ class MixerKwargs(BlockKwargs):
     cu_seqlens_k = "cu_seqlens_k"
     max_seqlen_q = "max_seqlen_q"
     max_seqlen_k = "max_seqlen_k"
+    min_seqlen_q = "min_seqlen_q"
+    min_seqlen_k = "min_seqlen_k"
     document_index_q = "document_index_q"
     document_index_k = "document_index_k"
     position_ids = "position_ids"
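
For reference, a standalone sketch of the call pattern the series converges
on: packed documents wrapped as a nested-jagged batch, with min/max sequence
lengths passed as plain ints so the nested dispatch never syncs. The tensor
names, sizes, and dtypes below are made up for the example (this is not code
from the patches) and it needs a CUDA device to run:

    import torch

    lengths = [15, 4, 6]
    heads, head_size = 4, 64
    offsets = torch.tensor([0] + lengths).cumsum(0).to(torch.int64).cuda()

    query = torch.randn(sum(lengths), heads, head_size, device="cuda", dtype=torch.bfloat16)
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    def as_nested(x: torch.Tensor) -> torch.Tensor:
        # (total_tokens, heads, head_size) -> nested (docs, seq*, heads, head_size)
        return torch.nested.nested_tensor_from_jagged(
            x, offsets, min_seqlen=min(lengths), max_seqlen=max(lengths)
        )

    out = torch.nn.functional.scaled_dot_product_attention(
        as_nested(query).transpose(1, 2),
        as_nested(key).transpose(1, 2),
        as_nested(value).transpose(1, 2),
        is_causal=True,
    ).transpose(1, 2)
    out = out.values()  # back to (total_tokens, heads, head_size)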