feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference #188
Open
maxiaosong1124 wants to merge 13 commits into RL-Align:main from maxiaosong1124:feat/ws1-attention-pytorch-op
444eb3b feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference (maxiaosong1124)
39df0d2 Merge branch 'RL-Align:main' into feat/ws1-attention-pytorch-op (maxiaosong1124)
256ec1d fix(ws1): address CodeRabbit review on attention op (maxiaosong1124)
73a273b fix(ws1): run attention tests in CI, relax padding bitwise claim to a… (maxiaosong1124)
c194db1 Merge branch 'main' into feat/ws1-attention-pytorch-op (maxiaosong1124)
bf6e055 Merge branch 'RL-Align:main' into feat/ws1-attention-pytorch-op (maxiaosong1124)
1e3a990 fix(ws1): bump padding atol to 2e-6 and clarify Axis-A doc per review (maxiaosong1124)
25d7377 Merge remote-tracking branch 'upstream/main' into feat/ws1-attention-… (maxiaosong1124)
674e83a test(ws1): address PR #188 review — gradient vs fp64 reference (KJLde… (maxiaosong1124)
91c6b33 docs(ws1): sync attention.md contract with impl per PR #195 review (maxiaosong1124)
22d45e6 test(ws1): add backward batch-invariance slice test (Flink-ddd) (maxiaosong1124)
81eeb69 Merge branch 'main' into feat/ws1-attention-pytorch-op (maxiaosong1124)
b022348 Merge branch 'main' into feat/ws1-attention-pytorch-op (maxiaosong1124)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| # Standard Softmax Attention | ||
|
|
||
| The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a | ||
| **WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of | ||
| the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and | ||
| friends) are validated against. | ||
|
|
||
| - **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive | ||
| softmax with a **fixed reduction order**, deliberately **not** | ||
| `F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order | ||
| is unspecified and would break the batch-invariance contract). | ||
|
|
||
| This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before* | ||
| the call (see the chain test), so the `q`, `k` passed in are already normalized and rotated. | ||
|
|
||
| ```text | ||
| q --\ | ||
| k ----softmax(QKᵀ/√d + mask)·V--> out | ||
| v --/ | ||
| ``` | ||
|
|
||
| ## Entry Point | ||
| ```python | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| attn = kernel_registry.get_op("attention") | ||
|
|
||
| # Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache) | ||
| out = attn(q, k, v, causal=True) # [B, 32, Sq, 128] | ||
| out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale | ||
| out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep | ||
| ``` | ||
|
|
||
| The op exposes the WS1 dual-path contract: | ||
|
|
||
| - `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy | ||
| candidate / dtype-behavior path). | ||
| - `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth | ||
| golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of | ||
| the caller's ambient precision context. | ||
|
|
||
| > **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production | ||
| > SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under | ||
| > `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap. | ||
|
|
||
| ## Backends | ||
|
|
||
| | Backend | Wrapper | Native symbol | Status | | ||
| | --- | --- | --- | --- | | ||
| | PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. | | ||
| | CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. | | ||
|
|
||
| ## Tensor Contract | ||
|
|
||
| | Argument | Shape | Dtype | Requirements | | ||
| | --- | --- | --- | --- | | ||
| | `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. | | ||
| | `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. | | ||
| | `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. | | ||
| | `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. | | ||
| | `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. An explicit value (incl. `0.0`) is used verbatim. | | ||
| | `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. | | ||
| | output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). | | ||
|
|
||
| **GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with | ||
| `repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`. | ||
|
|
||
| **Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single | ||
| expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees | ||
| the whole cache). | ||
|
|
||
| Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)` | ||
| preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on | ||
| the inputs' device. | ||
|
|
||
| ## Dispatch Behavior | ||
|
|
||
| `kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On | ||
| `cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op | ||
| (`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` -> | ||
| `forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden | ||
| path. When fused attention kernels land, they are prepended to the priority list and the native | ||
| op becomes the fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is | ||
| a separate dispatch chain and is unaffected. | ||
|
|
||
| ## Accuracy | ||
|
|
||
| Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled): | ||
|
|
||
| ```python | ||
| qf, kf, vf = q.float(), k.float(), v.float() | ||
| if Hkv != Hq:  # GQA: replicate KV, 32 Q / 8 KV, r = 4 | ||
|     r = Hq // Hkv | ||
|     kf = kf.repeat_interleave(r, dim=1) | ||
|     vf = vf.repeat_interleave(r, dim=1) | ||
| scale = scale if scale is not None else 1.0 / math.sqrt(D)  # D=128 → 1/√128 | ||
| scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale  # [B, Hq, Sq, Skv] | ||
| if causal:  # offset covers prefill + decode | ||
|     m = torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1) | ||
|     scores = scores.masked_fill(m, float("-inf")) | ||
| if key_padding_mask is not None:  # True = keep ; False columns → -inf | ||
|     scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) | ||
| probs = torch.softmax(scores, dim=-1) # subtracts per-row max internally | ||
| out = torch.matmul(probs, vf) # [B, Hq, Sq, D] | ||
| ``` | ||
|
|
||
| - **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and | ||
| autocast disabled so it is not silently downcast by the caller's ambient context. | ||
| - **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions | ||
| over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a | ||
| tolerance, not bitwise equality. | ||
| - **Axis A — batch invariance**: each query row reduces over the keys independently of how many | ||
| sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`) | ||
| across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax | ||
| reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width | ||
| (e.g. Skv=10 vs 6), and IEEE 754 addition is not associative, so the masked result matches the | ||
| valid-only result only up to a small tolerance (`atol=2e-6`), not bitwise. | ||
| - **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` | ||
| row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): | ||
|
|
||
| | dtype | max_abs / peak | threshold (rel-peak) | | ||
| | --- | --- | --- | | ||
| | bfloat16 | ~0.56 % | 3 % | | ||
| | float16 | ~0.07 % | 0.5 % | | ||
|
|
||
| ## Performance Notes | ||
|
|
||
| Reference operator — no fused kernel or benchmark yet. Downstream fused attention kernels carry | ||
| their own benchmarks and are measured against this reference for correctness. At the LARGE | ||
| Qwen3-8B load point (`B=8`, `Hq=32`, `Sq=Skv=4096`) the fp32 scores tensor alone is ~17 GB | ||
| (8 · 32 · 4096 · 4096 elements × 4 bytes) and the naive path peaks at ~3× that, so the LARGE | ||
| smoke test is GPU-only and skips without enough memory. | ||
|
|
||
| ## Tests | ||
|
|
||
| ```bash | ||
| python -m pytest tests/test_attention.py -v | ||
| ``` | ||
|
|
||
| Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile | ||
| autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard, | ||
| scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch | ||
| invariance (slice + chunked, bitwise; padding is near-equality only, see Accuracy), input purity, | ||
| gradient flow, registry dispatch, and a GPU-only LARGE Qwen3-8B real-shape smoke test. | ||
|
|
||
| ## Implementation Files | ||
|
|
||
| - `rl_engine/kernels/ops/pytorch/attention/standard_attn.py` | ||
| - `rl_engine/kernels/registry.py` | ||
| - `tests/test_attention.py` | ||
|
|
||
| ## Known Limitations | ||
|
|
||
| - PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). | ||
| - `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise). | ||
| - The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking, | ||
| so the LARGE load point is memory-heavy and GPU-only. | ||
| - Covers softmax attention only; QK-Norm and RoPE are applied before the call. |
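Review note on the causal offset: the `Skv - Sq + 1` diagonal in the Tensor Contract can be sanity-checked without PyTorch. The sketch below is a dependency-free illustration, not code from this PR. `causal_masked` and `visible_keys` are hypothetical helper names, and it assumes the `torch.triu` convention that `diagonal=d` keeps the entries with `j - i >= d`.

```python
def causal_masked(sq, skv):
    """Return an Sq x Skv grid where True means "masked to -inf".

    Mirrors triu(ones(Sq, Skv), diagonal=Skv - Sq + 1): entry (i, j) is
    masked when j - i >= Skv - Sq + 1.
    """
    offset = skv - sq + 1
    return [[(j - i) >= offset for j in range(skv)] for i in range(sq)]


def visible_keys(sq, skv):
    """Number of keys each query row can still attend to."""
    return [sum(not masked for masked in row) for row in causal_masked(sq, skv)]


prefill = visible_keys(4, 4)  # Sq == Skv: row i sees keys 0..i
decode = visible_keys(1, 6)   # one new query against a 6-key cache
chunked = visible_keys(2, 5)  # two new queries appended to a 3-key cache
beyond = visible_keys(3, 2)   # Sq > Skv: outside the documented range
```

Under these assumptions the prefill rows see 1 to `Sq` keys, the single decode query sees the whole cache, and each chunked query sees the cache plus its own prefix. The `Sq > Skv` call is outside the documented prefill/decode range: there the same formula leaves the earliest query row with no visible key.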
191 changes: 191 additions & 0 deletions
rl_engine/kernels/ops/pytorch/attention/standard_attn.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import math | ||
| from contextlib import contextmanager, nullcontext | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| class NativeAttentionOp: | ||
|     """ | ||
|     Pure PyTorch native standard-softmax attention reference. | ||
|         out = softmax(Q Kᵀ * scale + masks) @ V | ||
|
|
||
|     Hand-written naive softmax -- deliberately NOT | ||
|     ``F.scaled_dot_product_attention`` / flash / mem-efficient attention, whose | ||
|     reduction order is unspecified and would break the batch-invariance (Axis-A) | ||
|     contract. This op defines the *correct answer* the fused kernels align to. | ||
|
|
||
|     Qwen3-8B shapes: q ``[B, 32, Sq, 128]``, k/v ``[B, 8, Skv, 128]`` (GQA group | ||
|     g = 32/8 = 4), scale = 1/sqrt(head_dim) = 1/sqrt(128). Heads precede seq in | ||
|     the layout. This op covers ONLY the softmax attention; QK-Norm and RoPE are | ||
|     applied *before* the call (see the chain test) -- the q,k passed in are | ||
|     already normalized and rotated. | ||
|
|
||
|     This is a reduction over the key dimension (Skv): the low-precision | ||
|     ``forward`` path accumulates in the input dtype and therefore drifts from | ||
|     the fp32 ``forward_fp32`` ground truth, so Axis-B accuracy uses a tolerance | ||
|     (``torch.allclose``), not bitwise equality. Axis-A batch invariance still | ||
|     holds bitwise within a single dtype (each query row reduces over the keys | ||
|     independently of how many sequences share the batch). | ||
|
|
||
|     Masking conventions: | ||
|       * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1, valid for | ||
|         both prefill (Sq==Skv) and decode (Sq<Skv). | ||
|       * key_padding_mask ``[B, Skv]`` bool, True=valid / False=padding -> padded | ||
|         key columns set to -inf (matches reference_ops.py: True=keep, False=mask). | ||
|     """ | ||
|
|
||
|     def __init__(self) -> None: | ||
|         """No state; the op is a pure function over (q, k, v, ...).""" | ||
|
|
||
|     def __call__( | ||
|         self, | ||
|         q: torch.Tensor, | ||
|         k: torch.Tensor, | ||
|         v: torch.Tensor, | ||
|         *, | ||
|         causal: bool = True, | ||
|         scale: Optional[float] = None, | ||
|         key_padding_mask: Optional[torch.Tensor] = None, | ||
|     ) -> torch.Tensor: | ||
|         """Alias for ``forward`` so the op is callable like a module.""" | ||
|         return self.forward(q, k, v, causal=causal, scale=scale, key_padding_mask=key_padding_mask) | ||
|
|
||
|     def forward( | ||
|         self, | ||
|         q: torch.Tensor, | ||
|         k: torch.Tensor, | ||
|         v: torch.Tensor, | ||
|         *, | ||
|         causal: bool = True, | ||
|         scale: Optional[float] = None, | ||
|         key_padding_mask: Optional[torch.Tensor] = None, | ||
|     ) -> torch.Tensor: | ||
|         """ | ||
|         Canonical entry: attend in the input dtype, output the input dtype. | ||
|         This is the dtype-behavior path used as the Axis-B accuracy candidate. | ||
|         """ | ||
|         return self._attention( | ||
|             q, | ||
|             k, | ||
|             v, | ||
|             causal=causal, | ||
|             scale=scale, | ||
|             key_padding_mask=key_padding_mask, | ||
|             compute_dtype=q.dtype, | ||
|             output_dtype=q.dtype, | ||
|         ) | ||
|
|
||
|     def forward_fp32( | ||
|         self, | ||
|         q: torch.Tensor, | ||
|         k: torch.Tensor, | ||
|         v: torch.Tensor, | ||
|         *, | ||
|         causal: bool = True, | ||
|         scale: Optional[float] = None, | ||
|         key_padding_mask: Optional[torch.Tensor] = None, | ||
|     ) -> torch.Tensor: | ||
|         """Ground truth: upcast to fp32, accumulate in fp32, force fp32 output. | ||
|
|
||
|         The whole score->softmax->value path is wrapped to disable autocast and | ||
|         TF32 so this stays a true fp32 reference regardless of the caller's | ||
|         ambient precision context. | ||
|         """ | ||
|         return self._attention( | ||
|             q, | ||
|             k, | ||
|             v, | ||
|             causal=causal, | ||
|             scale=scale, | ||
|             key_padding_mask=key_padding_mask, | ||
|             compute_dtype=torch.float32, | ||
|             output_dtype=torch.float32, | ||
|             strict_fp32=True, | ||
|         ) | ||
|
|
||
|     # ------------------------------------------------------------------ # | ||
|     # Helpers | ||
|     # ------------------------------------------------------------------ # | ||
|     @staticmethod | ||
|     def _attention( | ||
|         q: torch.Tensor, | ||
|         k: torch.Tensor, | ||
|         v: torch.Tensor, | ||
|         *, | ||
|         causal: bool, | ||
|         scale: Optional[float], | ||
|         key_padding_mask: Optional[torch.Tensor], | ||
|         compute_dtype: torch.dtype, | ||
|         output_dtype: torch.dtype, | ||
|         strict_fp32: bool = False, | ||
|     ) -> torch.Tensor: | ||
|         """Core softmax attention: cast to ``compute_dtype``, score, mask, softmax, | ||
|         weighted-sum over V, cast out. ``strict_fp32`` disables autocast/TF32 so the | ||
|         fp32 reference is not silently downcast by the caller's ambient context. | ||
|         """ | ||
|         Hq, Sq, D = q.shape[1], q.shape[2], q.shape[3] | ||
|         Hkv, Skv = k.shape[1], k.shape[2] | ||
|         if Hq % Hkv != 0: | ||
|             raise ValueError(f"Hq={Hq} not divisible by Hkv={Hkv} (GQA group)") | ||
|
|
||
|         ctx = NativeAttentionOp._strict_fp32_math(q.device.type) if strict_fp32 else nullcontext() | ||
|         with ctx: | ||
|             qf = q.to(compute_dtype) | ||
|             kf = k.to(compute_dtype) | ||
|             vf = v.to(compute_dtype) | ||
|
|
||
|             # GQA: replicate each KV head g=Hq//Hkv times (Qwen3: 32/8 -> 4). | ||
|             # repeat_interleave (not repeat) keeps each KV head's copies adjacent | ||
|             # so query head h maps to KV head h // g. | ||
|             if Hkv != Hq: | ||
|                 r = Hq // Hkv | ||
|                 kf = kf.repeat_interleave(r, dim=1) | ||
|                 vf = vf.repeat_interleave(r, dim=1) | ||
|
|
||
|             # scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept. | ||
|             scale = scale if scale is not None else (1.0 / math.sqrt(D)) | ||
|             scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale  # [B, Hq, Sq, Skv] | ||
|
|
||
|             # Causal: offset Skv-Sq+1 covers prefill (Sq==Skv) and decode (Sq<Skv). | ||
|             if causal: | ||
|                 causal_mask = torch.triu( | ||
|                     torch.ones(Sq, Skv, dtype=torch.bool, device=q.device), | ||
|                     diagonal=Skv - Sq + 1, | ||
|                 ) | ||
|                 scores = scores.masked_fill(causal_mask, float("-inf")) | ||
|
|
||
|             # key_padding_mask [B, Skv]: True=valid; False columns -> -inf. | ||
|             if key_padding_mask is not None: | ||
|                 pad = ~key_padding_mask | ||
|                 scores = scores.masked_fill(pad[:, None, None, :], float("-inf")) | ||
|
|
||
|             probs = torch.softmax(scores, dim=-1)  # subtracts row max internally | ||
|             if key_padding_mask is not None: | ||
|                 # A query whose keys are all padded out has an all -inf row; | ||
|                 # softmax would emit NaN. Define such fully-masked rows as 0 (no key | ||
|                 # contributes), keeping outputs/grads finite. Only key_padding_mask can | ||
|                 # fully mask a row -- causal alone always leaves a query its own key | ||
|                 # when Sq <= Skv -- so this guard lives in the padding branch and the | ||
|                 # no-pad path is untouched. Row-independent, so Axis-A batch invariance | ||
|                 # still holds. | ||
|                 all_masked = ~torch.isfinite(scores).any(dim=-1, keepdim=True) | ||
|                 probs = torch.where(all_masked, torch.zeros_like(probs), probs) | ||
|             out = torch.matmul(probs, vf)  # [B, Hq, Sq, D] | ||
|
|
||
|         return out.to(output_dtype) | ||
|
|
||
|     @staticmethod | ||
|     @contextmanager | ||
|     def _strict_fp32_math(device_type: str): | ||
|         """Disable autocast and TF32 for a true fp32 path, restoring state after.""" | ||
|         prev_tf32 = torch.backends.cuda.matmul.allow_tf32 | ||
|         torch.backends.cuda.matmul.allow_tf32 = False | ||
|         try: | ||
|             with torch.autocast(device_type=device_type, enabled=False): | ||
|                 yield | ||
|         finally: | ||
|             torch.backends.cuda.matmul.allow_tf32 = prev_tf32 | ||
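Review note on the fully-masked-row guard: a plain-Python sketch of why an all `-inf` score row needs special handling. It is an illustration under stated assumptions rather than the PR's code. `softmax_row` and `guarded_softmax_row` are hypothetical names, and the guard here short-circuits before the softmax, whereas `_attention` computes the softmax first and then zeroes the row with `torch.where`; the resulting probabilities should be the same.

```python
import math

NEG_INF = float("-inf")


def softmax_row(scores):
    """Max-subtracted softmax over one row of scores."""
    row_max = max(scores)
    # For an all -inf row, (-inf) - (-inf) is NaN, so every term becomes NaN.
    exps = [math.exp(s - row_max) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def guarded_softmax_row(scores):
    """Same softmax, but a row with no finite score is defined as all zeros."""
    if not any(math.isfinite(s) for s in scores):
        return [0.0] * len(scores)
    return softmax_row(scores)


fully_masked = [NEG_INF, NEG_INF, NEG_INF]
naive = softmax_row(fully_masked)            # every entry is NaN
guarded = guarded_softmax_row(fully_masked)  # every entry is 0.0
partial = guarded_softmax_row([0.0, NEG_INF, 0.0])  # masked key gets weight 0
```

Rows that keep at least one valid key are untouched by the guard, which is consistent with the no-padding path staying bitwise unchanged.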