-
Notifications
You must be signed in to change notification settings - Fork 42
feat(ws1): add NativeKVCacheAttnOp pure-PyTorch KV-cache attention reference #195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
maxiaosong1124
wants to merge
13
commits into
RL-Align:main
Choose a base branch
from
maxiaosong1124:feat/ws1-kv-cache-attention-pytorch-op
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
444eb3b
feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference
maxiaosong1124 39df0d2
Merge branch 'RL-Align:main' into feat/ws1-attention-pytorch-op
maxiaosong1124 256ec1d
fix(ws1): address CodeRabbit review on attention op
maxiaosong1124 73a273b
fix(ws1): run attention tests in CI, relax padding bitwise claim to a…
maxiaosong1124 c194db1
Merge branch 'main' into feat/ws1-attention-pytorch-op
maxiaosong1124 bf6e055
Merge branch 'RL-Align:main' into feat/ws1-attention-pytorch-op
maxiaosong1124 1e3a990
fix(ws1): bump padding atol to 2e-6 and clarify Axis-A doc per review
maxiaosong1124 25d7377
Merge remote-tracking branch 'upstream/main' into feat/ws1-attention-…
maxiaosong1124 674e83a
test(ws1): address PR #188 review — gradient vs fp64 reference (KJLde…
maxiaosong1124 91c6b33
docs(ws1): sync attention.md contract with impl per PR #195 review
maxiaosong1124 4e6df2c
feat(ws1): add NativeKVCacheAttnOp pure-PyTorch KV-cache attention re…
maxiaosong1124 4a23b5f
test+fix(ws1): address PR #195 review — Sq==S_new guard + contract-al…
maxiaosong1124 c581060
Merge branch 'main' into feat/ws1-kv-cache-attention-pytorch-op
maxiaosong1124 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| # Standard Softmax Attention | ||
|
|
||
| The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a | ||
| **WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of | ||
| the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and | ||
| friends) are validated against. | ||
|
|
||
| - **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive | ||
| softmax with a **fixed reduction order**, deliberately **not** | ||
| `F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order | ||
| is unspecified and would break the batch-invariance contract). | ||
|
|
||
| This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before* | ||
| the call (see the chain), so the `q`, `k` passed in are already normalized and rotated. | ||
|
|
||
| ```text | ||
| q --\ | ||
| k ----softmax(QKᵀ/√d + mask)·V--> out | ||
| v --/ | ||
| ``` | ||
|
|
||
| ## Entry Point | ||
| ```python | ||
| from rl_engine.kernels.registry import kernel_registry | ||
|
|
||
| attn = kernel_registry.get_op("attention") | ||
|
|
||
| # Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache) | ||
| out = attn(q, k, v, causal=True) # [B, 32, Sq, 128] | ||
| out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale | ||
| out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep | ||
| ``` | ||
|
|
||
| The op exposes the WS1 dual-path contract: | ||
|
|
||
| - `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy | ||
| candidate / dtype-behavior path). | ||
| - `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth | ||
| golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of | ||
| the caller's ambient precision context. | ||
|
|
||
| > **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production | ||
| > SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under | ||
| > `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap. | ||
|
|
||
| ## Backends | ||
|
|
||
| | Backend | Wrapper | Native symbol | Status | | ||
| | --- | --- | --- | --- | | ||
| | PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. | | ||
| | CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. | | ||
|
|
||
| ## Tensor Contract | ||
|
|
||
| | Argument | Shape | Dtype | Requirements | | ||
| | --- | --- | --- | --- | | ||
| | `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. | | ||
| | `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. | | ||
| | `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. | | ||
| | `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. | | ||
| | `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. An explicit value (incl. `0.0`) is used verbatim. | | ||
| | `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. | | ||
| | output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). | | ||
|
|
||
| **GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with | ||
| `repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`. | ||
|
|
||
| **Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single | ||
| expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees | ||
| the whole cache). | ||
|
|
||
| Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)` | ||
| preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on | ||
| the inputs' device. | ||
|
|
||
| ## Dispatch Behavior | ||
|
|
||
| `kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On | ||
| `cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op | ||
| (`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` -> | ||
| `forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden | ||
| path. When fused attention kernels land, they are prepended to the priority list and the native | ||
| op becomes the fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is | ||
| a separate dispatch chain and is unaffected. | ||
|
|
||
| ## Accuracy | ||
|
|
||
| Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled): | ||
|
|
||
| ```python | ||
| qf, kf, vf = q.float(), k.float(), v.float() | ||
| if Hkv != Hq: # GQA: replicate KV, 32 Q / 8 KV, r = 4 | ||
| r = Hq // Hkv | ||
| kf = kf.repeat_interleave(r, dim=1) | ||
| vf = vf.repeat_interleave(r, dim=1) | ||
| scale = scale if scale is not None else 1.0 / math.sqrt(D) # D=128 → 1/√128 | ||
| scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] | ||
| if causal: # offset covers prefill + decode | ||
| m = torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1) | ||
| scores = scores.masked_fill(m, float("-inf")) | ||
| if key_padding_mask is not None: # True = keep ; False columns → -inf | ||
| scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) | ||
| probs = torch.softmax(scores, dim=-1) # subtracts per-row max internally | ||
| out = torch.matmul(probs, vf) # [B, Hq, Sq, D] | ||
| ``` | ||
|
|
||
| - **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and | ||
| autocast disabled so it is not silently downcast by the caller's ambient context. | ||
| - **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions | ||
| over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a | ||
| tolerance, not bitwise equality. | ||
| - **Axis A — batch invariance**: each query row reduces over the keys independently of how many | ||
| sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`) | ||
| across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax | ||
| reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width | ||
| (e.g. Skv=10 vs 6), so the masked result only matches the valid-only result up to a small | ||
| tolerance (`atol=2e-6`), not bitwise, in IEEE 754. | ||
| - **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` | ||
| row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): | ||
|
|
||
| | dtype | max_abs / peak | threshold (rel-peak) | | ||
| | --- | --- | --- | | ||
| | bfloat16 | ~0.56 % | 3 % | | ||
| | float16 | ~0.07 % | 0.5 % | | ||
|
|
||
| ## Performance Notes | ||
|
|
||
| Reference operator — no fused kernel or benchmark yet. Downstream fused attention kernels carry | ||
| their own benchmarks and are measured against this reference for correctness. At the LARGE | ||
| Qwen3-8B load point (`B=8`, `Skv=4096`, `Hq=32`) the fp32 scores tensor alone is ~17 GB and the | ||
| naive path peaks at ~3× that, so the LARGE smoke test is GPU-only and skips without enough | ||
| memory. | ||
|
|
||
| ## Tests | ||
|
|
||
| ```bash | ||
| python -m pytest tests/test_attention.py -v | ||
| ``` | ||
|
|
||
| Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile | ||
| autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard, | ||
| scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch | ||
| invariance (slice + chunked, bitwise; padding is near-equality only, see below), input purity, | ||
| gradient flow, registry dispatch, and a | ||
| GPU-only LARGE Qwen3-8B real-shape smoke test. | ||
|
|
||
| ## Implementation Files | ||
|
|
||
| - `rl_engine/kernels/ops/pytorch/attention/standard_attn.py` | ||
| - `rl_engine/kernels/registry.py` | ||
| - `tests/test_attention.py` | ||
|
|
||
| ## Known Limitations | ||
|
|
||
| - PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). | ||
| - `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise). | ||
| - The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking, | ||
| so the LARGE load point is memory-heavy and GPU-only. | ||
| - Covers softmax attention only; QK-Norm and RoPE are applied before the call. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright (c) 2026 RL-Kernel Contributors | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Optional | ||
|
|
||
| import torch | ||
|
|
||
| from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp | ||
|
|
||
|
|
||
| class NativeKVCacheAttnOp: | ||
| """ | ||
| Pure PyTorch native KV-cache attention reference (ISSUE #108 WS1). | ||
|
|
||
| Decode/incremental attention: the past keys/values live in a cache and the | ||
| step's new keys/values are appended before attending:: | ||
|
|
||
| k_full = cat([k_cache, k_new], dim=2) # along the seq axis | ||
| v_full = cat([v_cache, v_new], dim=2) | ||
| out = NativeAttentionOp().forward_fp32(q, k_full, v_full, ...) | ||
|
|
||
| The whole point is that it delegates to the *same* ``NativeAttentionOp`` used | ||
| for full-sequence (prefill) attention: prefill and decode therefore share one | ||
| reduction path, which is what makes rollout (decode) numerically consistent | ||
| with training (prefill). Re-implementing the softmax here would defeat that. | ||
|
|
||
| Qwen3-8B shapes (synthetic tensors, no checkpoint): q ``[B, 32, Sq, 128]``, | ||
| cache/new k,v ``[B, 8, S_past/S_new, 128]`` (GQA group g = 32/8 = 4). Heads | ||
| precede seq in the layout; the GQA KV replication happens inside | ||
| ``NativeAttentionOp``, not here. This is a reduction over the full key length | ||
| Skv = S_past + S_new. | ||
|
|
||
| Alignment assumption: q's ``Sq`` rows are the *last* Sq positions of the full | ||
| sequence -- i.e. the queries for ``k_new`` -- so callers pass ``Sq == S_new`` | ||
| (decode: Sq == S_new == 1). The causal offset ``Skv - Sq + 1`` inside | ||
| ``NativeAttentionOp`` then lets each new query see the whole cache plus the | ||
| new tokens up to and including itself, for both decode and chunked prefill. | ||
|
|
||
| Masking conventions (forwarded verbatim to ``NativeAttentionOp``): | ||
| * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1. | ||
| * key_padding_mask ``[B, Skv]`` bool over the *concatenated* length | ||
| (S_past + S_new), True=valid / False=padding. | ||
|
|
||
| Only the attention output is returned; producing an updated cache is a | ||
| caller/runtime concern and is not part of this numerical contract. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| """No state; the op is a pure function over (q, k_cache, v_cache, k_new, v_new, ...).""" | ||
| self._attn = NativeAttentionOp() | ||
|
|
||
| def __call__( | ||
| self, | ||
| q: torch.Tensor, | ||
| k_cache: torch.Tensor, | ||
| v_cache: torch.Tensor, | ||
| k_new: torch.Tensor, | ||
| v_new: torch.Tensor, | ||
| *, | ||
| causal: bool = True, | ||
| scale: Optional[float] = None, | ||
| key_padding_mask: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| """Alias for ``forward`` so the op is callable like a module.""" | ||
| return self.forward( | ||
| q, | ||
| k_cache, | ||
| v_cache, | ||
| k_new, | ||
| v_new, | ||
| causal=causal, | ||
| scale=scale, | ||
| key_padding_mask=key_padding_mask, | ||
| ) | ||
|
|
||
| def forward( | ||
| self, | ||
| q: torch.Tensor, | ||
| k_cache: torch.Tensor, | ||
| v_cache: torch.Tensor, | ||
| k_new: torch.Tensor, | ||
| v_new: torch.Tensor, | ||
| *, | ||
| causal: bool = True, | ||
| scale: Optional[float] = None, | ||
| key_padding_mask: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Canonical entry: concat cache+new, then attend in the input dtype. | ||
| Delegates to ``NativeAttentionOp.forward`` (the Axis-B dtype path). | ||
| """ | ||
| self._validate_decode_alignment(q, k_new, v_new) | ||
| k_full, v_full = self._concat_kv(k_cache, v_cache, k_new, v_new) | ||
| return self._attn.forward( | ||
| q, k_full, v_full, causal=causal, scale=scale, key_padding_mask=key_padding_mask | ||
| ) | ||
|
|
||
| def forward_fp32( | ||
| self, | ||
| q: torch.Tensor, | ||
| k_cache: torch.Tensor, | ||
| v_cache: torch.Tensor, | ||
| k_new: torch.Tensor, | ||
| v_new: torch.Tensor, | ||
| *, | ||
| causal: bool = True, | ||
| scale: Optional[float] = None, | ||
| key_padding_mask: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| """Ground truth: concat cache+new, then attend in strict fp32. | ||
| Delegates to ``NativeAttentionOp.forward_fp32`` so the fp32 golden path | ||
| is identical to prefill's. | ||
| """ | ||
| self._validate_decode_alignment(q, k_new, v_new) | ||
| k_full, v_full = self._concat_kv(k_cache, v_cache, k_new, v_new) | ||
| return self._attn.forward_fp32( | ||
| q, k_full, v_full, causal=causal, scale=scale, key_padding_mask=key_padding_mask | ||
| ) | ||
|
|
||
| # ------------------------------------------------------------------ # | ||
| # Helpers | ||
| # ------------------------------------------------------------------ # | ||
| @staticmethod | ||
| def _validate_decode_alignment( | ||
| q: torch.Tensor, | ||
| k_new: torch.Tensor, | ||
| v_new: torch.Tensor, | ||
| ) -> None: | ||
| """Enforce the contract that ``q`` holds exactly the newly appended | ||
| positions: ``Sq == S_new``. | ||
|
|
||
| q's rows are the queries for ``k_new``, so the causal offset inside | ||
| ``NativeAttentionOp`` (``Skv - Sq + 1``) is only correct when their seq | ||
| lengths match. A mismatched ``q`` would otherwise silently attend with | ||
| the wrong offset and return a wrong-but-finite result. Seq axis is dim=2 | ||
| in the ``[B, H, S, D]`` layout. | ||
| """ | ||
| sq, s_new_k, s_new_v = q.size(2), k_new.size(2), v_new.size(2) | ||
| if sq != s_new_k or sq != s_new_v: | ||
| raise ValueError( | ||
| "kv_cache attention expects q to hold exactly the new positions " | ||
| f"(Sq == S_new): got Sq={sq}, k_new={s_new_k}, v_new={s_new_v}." | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _concat_kv( | ||
| k_cache: torch.Tensor, | ||
| v_cache: torch.Tensor, | ||
| k_new: torch.Tensor, | ||
| v_new: torch.Tensor, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Append the step's new K/V to the cache along the seq axis (dim=2). | ||
|
|
||
| Layout is [B, Hkv, S, D], so the sequence axis is dim=2. Pure (no | ||
| in-place writes into the passed-in cache tensors). | ||
| """ | ||
| k_full = torch.cat([k_cache, k_new], dim=2) | ||
| v_full = torch.cat([v_cache, v_new], dim=2) | ||
| return k_full, v_full | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.