Add SDPA attention implementation #512
jlamypoirier wants to merge 6 commits into main from
Conversation
Flash-attn errors out at head_size > 256, so head_size=512 models cannot train without materializing the full O(S²) attention matrix via the backup path. Add `AttentionImplementation.sdpa` using `torch.nested` to bridge the packed-varlen layout to SDPA's batched signature, pinning the EFFICIENT backend. K/V are manually repeat_interleaved to match Q heads because the fused kernels reject broadcasted GQA inputs.

Auto-fallback: flash when bf16/fp16 + head_size <= 256 + flash is available; backup for windowed attention (the sdpa path does not support sliding window); sdpa otherwise.

Tests: SDPA equivalence check parallel to flash, gated on CUDA + bf16; two head_size=320 cases exercising the SDPA-only regime; refactored parametrization from `_build_test_cases` plus single-use variant lists into a few inline for-loops at module level.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The SDPA path uses `nested_tensor_from_jagged + is_causal=True`, which has no viable backend on CPU (math rejects nested + is_causal; the fused EFFICIENT/Flash backends are CUDA-only). Auto previously routed CPU runs through SDPA and they would crash; route them to backup instead.

Also widens the SDPA branch to fp32 explicitly: the EFFICIENT backend engages on CUDA across bf16/fp16/fp32, and benchmarking confirms it beats backup on memory at every length and matches it on time at seq_len >= 4096 (backup grows quadratically; SDPA stays near constant).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous attempt routed CPU and windowed configurations to backup because the nested + is_causal=True form has no viable backend on CPU and cannot express sliding window. SDPA actually works fine in those cases when given an explicit attn_mask: backup's preprocessing already builds the combined causal+document mask (and threads sliding window into it), so the SDPA path can reuse it as-is.

CUDA without a window keeps the nested + is_causal path so EFFICIENT runs without materializing the mask. CUDA with a window and CPU runs both fall through to dense + attn_mask, which lets MATH engage on CPU and reuses the windowed mask on CUDA.

Auto-fallback simplifies to flash-or-sdpa: SDPA now covers every case backup used to (CPU, windowed without flash, head_size > 256). Verified on H100 bf16 head_size=512 that the dense + attn_mask form also engages EFFICIENT (peak 323 MiB vs 319 MiB for is_causal — the 4 MiB delta is the mask itself).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
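The semantics of the combined causal+document mask this commit reuses can be sketched from scratch like so (an illustration of the mask logic, not the repository's actual `preprocess` code):

```python
import torch

def causal_document_mask(document_index, window_size=None):
    # document_index: (total_tokens,) int tensor, document id per packed token
    pos = torch.arange(document_index.numel())
    same_doc = document_index[:, None] == document_index[None, :]
    causal = pos[:, None] >= pos[None, :]
    mask = same_doc & causal  # True = query may attend to key
    if window_size is not None:
        # Sliding window threaded into the same mask: a query may look
        # back at most window_size - 1 positions.
        mask &= pos[:, None] - pos[None, :] < window_size
    return mask

# Dense SDPA form: (1, H, total, D) tensors plus the boolean mask,
# broadcast across batch and heads.
q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
mask = causal_document_mask(torch.tensor([0, 0, 1, 1, 1]))
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Because causality keeps the diagonal attendable, every query row has at least one valid key, so the boolean mask never produces an all-masked softmax row.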
…on call

The CUDA-no-window and dense-mask paths shared the K/V expansion, the SDPA call signature (dropout + scale), and the (B, H, S, D) layout requirement. Lift those out: rebind query/key/value to either nested-jagged or unsqueeze(0)'d 4D tensors in the per-path setup, build an `sdpa_args` dict that adds `is_causal=...` for nested or `attn_mask=...` for dense, then make a single SDPA call that works for both. The unwrap branches on `output.is_nested`.

Also drops the explicit EFFICIENT_ATTENTION pin from the nested path — nested + is_causal=True has no other viable backend (MATH and Flash both reject it), so the auto pick lands on EFFICIENT or the call errors out either way.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The nested path floors per-call wall around 6 ms because SDPA's nested dispatch pulls `max_seqlen` / `min_seqlen` to host (5 cudaMemcpyAsync DtoH + cudaStreamSynchronize per call). Sync count is fixed regardless of num_docs, so the path stays much faster than dense+mask in varlen training; the comment just makes the cost discoverable. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
PyTorch's nested SDPA dispatch reads `max_seqlen` and `min_seqlen` to host on every call (5 cudaMemcpyAsync DtoH + cudaStreamSynchronize per call) when they aren't supplied. Both are trivially derivable from the Python `lengths` list at preprocessing time, so we compute them as plain ints, thread them through `BlockModelInput` / kwargs, and pass them to `nested_tensor_from_jagged`.

While doing this, drop the `torch.full((1,), ..., device=...)` wrap on `max_lengths` — the value was always a Python int, and flash accepts an int directly (verified). The auto-device-move on the `Document` base class only moves Tensor fields, so plain ints pass through to_kwargs untouched.

Sync events per call (Llama-7B-shape, 4 docs × 4096):
- before: 5 cudaStreamSynchronize + 5 Memcpy DtoH
- after: 0

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
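The precompute described here amounts to a few lines at preprocessing time. A sketch with an illustrative helper name (the actual threading through `BlockModelInput` is not shown):

```python
import itertools
import torch

def packed_seqlens(lengths):
    """lengths: plain Python list of per-document token counts."""
    # Offsets tensor shared by flash cu_seqlens and nested_tensor_from_jagged.
    cu_seqlens = torch.tensor([0, *itertools.accumulate(lengths)], dtype=torch.int64)
    # Plain Python ints: nothing lives on device, so the nested SDPA
    # dispatch never copies seqlen metadata back to host.
    return cu_seqlens, min(lengths), max(lengths)
```

These would then be forwarded as `nested_tensor_from_jagged(values, offsets=cu_seqlens, min_seqlen=..., max_seqlen=...)`, which is what keeps the dispatch sync-free.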
Summary
Flash-attn caps at `head_size = 256`; `head_size = 512` models (e.g. Gemma 4's full-attention layers) currently force the `backup` path, which materializes the full O(S²) attention matrix and OOMs above ~8K context on H100. Add `AttentionImplementation.sdpa` so those models can train.

The implementation has two CUDA-aware paths sharing the rest of the layer:

- `torch.nested.nested_tensor_from_jagged(values, cu_seqlens, min_seqlen=..., max_seqlen=...)` + `is_causal=True` under EFFICIENT. Each document becomes its own batch element, so cross-document attention is excluded by structure rather than by mask. Pre-computed `min_seqlen`/`max_seqlen` are passed in to keep the dispatch sync-free (otherwise it reads them to host on every call — 5 host barriers per layer).
- `(1, H, total, D)` + `attn_mask`, reusing backup's preprocessed causal+document mask. MATH cannot accept nested + `is_causal=True`, so the mask path is the only viable form on CPU; on CUDA-with-window the mask is needed because nested + `is_causal` cannot express sliding window. Per a cluster probe, EFFICIENT also engages on CUDA with explicit `attn_mask` — only ~4 MiB extra over `is_causal` for the mask itself.

K/V are manually `repeat_interleave`d across query heads in both paths because SDPA's fused kernels reject broadcasted GQA inputs.

Auto-fallback simplifies to `flash` for `bf16`/`fp16` + `head_size ≤ 256` + flash available, otherwise `sdpa`. SDPA now covers every previously-`backup` case (CPU, windowed without flash, `head_size > 256`); `backup` remains as an explicit `implementation: backup` option but the auto path no longer reaches it.

To pre-compute the seqlens for SDPA, the data preprocessor's `max_lengths` is changed from a 1-element device tensor to a Python `int` (flash accepts `int` natively, verified), `min_lengths` is added symmetrically, and a `return_min_sequence_lengths` flag is added to `LengthPreprocessingConfig`. Plain ints sail through `Document.to_device_` since it only moves `Tensor` fields. `get_preprocessing_config` branches by impl: flash needs cu_seqlens + max_seqlens; sdpa-CUDA-no-window needs cu_seqlens + max + min; sdpa-windowed / sdpa-CPU / backup all need document_index (mask is built in `preprocess` and shared).

Tests: SDPA equivalence check parallel to flash via a small `_check_packed` closure (CUDA bf16); two `head_size=320` cases that exercise the SDPA-only regime; windowed cases now exercise SDPA too. Parametrization refactored from `_build_test_cases` + single-use variant lists into inline for-loops at module level.

Benchmark — H100 bf16, 20 iters after 10 warmup, fwd+bwd wall
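The per-implementation branching described for `get_preprocessing_config` might look roughly like this. Field and flag names beyond `return_min_sequence_lengths` and `LengthPreprocessingConfig` are guesses reconstructed from the summary, not the repository's actual config:

```python
from dataclasses import dataclass

@dataclass
class LengthPreprocessingConfig:
    # Assumed flags mirroring the quantities named in the summary.
    return_cumulative_lengths: bool = False
    return_max_sequence_lengths: bool = False
    return_min_sequence_lengths: bool = False
    return_document_index: bool = False

def get_preprocessing_config(implementation, window_size, is_cuda):
    # flash: cu_seqlens + max_seqlens
    if implementation == "flash":
        return LengthPreprocessingConfig(
            return_cumulative_lengths=True,
            return_max_sequence_lengths=True,
        )
    # sdpa on CUDA without a window: nested path needs cu_seqlens + max + min
    if implementation == "sdpa" and is_cuda and window_size is None:
        return LengthPreprocessingConfig(
            return_cumulative_lengths=True,
            return_max_sequence_lengths=True,
            return_min_sequence_lengths=True,
        )
    # sdpa-windowed / sdpa-CPU / backup: dense mask built from document_index
    return LengthPreprocessingConfig(return_document_index=True)
```

The point of the branch is that each implementation requests only the metadata it consumes, so e.g. the dense-mask paths never pay for cu_seqlens they would not use.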
Llama-7B-shape (32 heads MHA, head_size=128):
Gemma-4 full-attn (16/8 GQA, head_size=512):
Multi-document varlen — the typical training case — is where nested+is_causal pulls ahead of mask by 2.6×–7×: nested processes each doc as its own batch element (4×4K² attention work) while mask materializes the full 16K² matrix even though same-doc cross-attention is then masked out. Backup OOMs above ~8K at these widths.
Sync events per nested SDPA call (profiled): 0 with pre-computed seqlens; 5 without. Pure wall-clock impact in synthetic bench is ~0.1–1 ms/call, but in a real training loop those 5 host barriers per layer × 30 layers × 8 microbatches = 1200 syncs/step would have prevented host-GPU overlap; with them gone, the nested path's ~6 ms of Python wrapping overhead can hide behind GPU compute.
Test plan
- `pytest -v -n 4 tests/layers/test_attention.py` (CPU): 56 passed
- `pytest -v -n 8 tests/layers/test_attention.py` (CUDA): 56 passed; all SDPA equivalence checks run, including windowed cases

🤖 Generated with Claude Code