8 changes: 8 additions & 0 deletions .github/workflows/ci.yml
@@ -68,6 +68,14 @@ jobs:
          python -m pytest rl_engine/tests/test_dispatch.py -v
          PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs

      - name: Run Attention Ground-Truth Tests (CPU-safe)
        run: |
          python -m pytest tests/test_attention.py -v -k "not large and not gpu"

      - name: Run KV-Cache Attention Ground-Truth Tests (CPU-safe)
        run: |
          python -m pytest tests/test_kv_cache_attention.py -v -k "not large and not gpu"

  docs:
    runs-on: ubuntu-latest
    steps:
1 change: 1 addition & 0 deletions docs/.nav.yml
@@ -12,6 +12,7 @@ nav:
  - Operators:
    - operators/README.md
    - operators/activation.md
    - operators/attention.md
    - operators/fused-logp.md
    - operators/linear-logp.md
    - operators/grpo-loss.md
1 change: 1 addition & 0 deletions docs/operators/README.md
@@ -19,6 +19,7 @@ Every operator page should include:
## Current Pages

- [SiLU / SwiGLU Activation](activation.md)
- [Standard Attention](attention.md)
- [Fused LogP](fused-logp.md)
- [Fused Linear LogP](linear-logp.md)
- [GRPO Loss](grpo-loss.md)
159 changes: 159 additions & 0 deletions docs/operators/attention.md
@@ -0,0 +1,159 @@
# Standard Softmax Attention

The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a
**WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of
the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and
friends) are validated against.

- **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive
softmax with a **fixed reduction order**, deliberately **not**
`F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order
is unspecified and would break the batch-invariance contract).
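
Why the reduction order matters can be seen without any tensor library. The sketch below is a stdlib-only illustration, not part of `NativeAttentionOp`: it sums the same three numbers with two different groupings and compares the results bit for bit.

```python
# Floating-point addition is not associative, so how a sum is grouped matters.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)

# The two groupings round differently and disagree in the last bit.
same_bits = left_to_right == right_to_left
print(same_bits, abs(left_to_right - right_to_left))
```

A kernel that is free to regroup its sums (for example by tiling differently for different batch sizes) can therefore return different bits for the same row, which is what a fixed reduction order rules out.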

This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before*
the call (see the chain), so the `q`, `k` passed in are already normalized and rotated.

```text
q --\
k ----softmax(QKᵀ/√d + mask)·V--> out
v --/
```

## Entry Point
```python
from rl_engine.kernels.registry import kernel_registry

attn = kernel_registry.get_op("attention")

# Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache)
out = attn(q, k, v, causal=True) # [B, 32, Sq, 128]
out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale
out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep
```

The op exposes the WS1 dual-path contract:

- `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy
candidate / dtype-behavior path).
- `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth
golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of
the caller's ambient precision context.
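
The difference between the two paths can be illustrated with a stdlib-only stand-in. This is a sketch under assumptions, not the op's implementation: `to_fp16` is a hypothetical helper that rounds through the `struct` half-precision format, and Python's fp64 floats stand in for the wider accumulator.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


values = [to_fp16(0.1)] * 1000  # inputs already stored in fp16

# Low-precision path: the running sum is rounded back to fp16 after every add.
acc_low = 0.0
for v in values:
    acc_low = to_fp16(acc_low + v)

# Wide path: same inputs, accumulated in a wider type (fp64 here, standing in for fp32).
acc_wide = sum(values)

rel_err = abs(acc_low - acc_wide) / abs(acc_wide)
print(acc_low, acc_wide, rel_err)
```

The inputs are identical in both paths; only the precision of the accumulator changes, and that alone is enough to move the result.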

> **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production
> SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under
> `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap.

## Backends

| Backend | Wrapper | Native symbol | Status |
| --- | --- | --- | --- |
| PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. |
| CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. |

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. |
| `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. |
| `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. |
| `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. |
| `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. An explicit value (incl. `0.0`) is used verbatim. |
| `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. |
| output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). |
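
The `key_padding_mask` convention (`True` = keep) can be sketched for a single score row in plain Python. This is an illustrative stand-in for the tensor version; `masked_softmax` is a hypothetical helper, and it assumes at least one key per row is kept (a fully masked row would produce NaN in this sketch).

```python
import math


def masked_softmax(scores, keep):
    """Softmax over one row; keys with keep=False are set to -inf first."""
    masked = [s if k else float("-inf") for s, k in zip(scores, keep)]
    row_max = max(masked)  # subtract the row max for numerical stability
    exps = [math.exp(s - row_max) for s in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.5, 1.5, -0.3, 2.0]
keep = [True, True, False, False]  # the last two keys are padding
probs = masked_softmax(scores, keep)
print(probs)
```

Padded keys receive exactly zero probability, so their value rows contribute nothing to the output even though key 3 had the largest raw score.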

**GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with
`repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`.
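
The head mapping is easy to check with plain lists. The two helpers below are hypothetical list analogues of the tensor operations along the head axis, written only to show why interleaving (and not tiling) yields `h // g`.

```python
def repeat_interleave(items, g):
    """Each element repeated g times in place: [a, b] -> [a, a, b, b]."""
    return [x for x in items for _ in range(g)]


def tile(items, g):
    """Whole sequence repeated g times: [a, b] -> [a, b, a, b]."""
    return list(items) * g


Hq, Hkv = 32, 8
g = Hq // Hkv  # group size, 4 for Qwen3-8B
kv_heads = list(range(Hkv))

interleaved = repeat_interleave(kv_heads, g)  # query head h -> KV head h // g
tiled = tile(kv_heads, g)                     # query head h -> KV head h % Hkv
print(interleaved[:8], tiled[:8])
```

With tiling, query heads 0 through 3 would read four different KV heads instead of sharing one, which is not the grouping described above.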

**Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single
expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees
the whole cache).
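
A small list-based sketch of the offset rule, assuming the usual upper-triangular convention that entry `(i, j)` is masked when `j - i >= diagonal`. `causal_mask` is a hypothetical helper, not a function from the repository.

```python
def causal_mask(Sq, Skv):
    """True = masked. Mirrors an upper-triangular mask with diagonal = Skv - Sq + 1."""
    offset = Skv - Sq + 1
    return [[(j - i) >= offset for j in range(Skv)] for i in range(Sq)]


prefill = causal_mask(3, 3)  # each query sees itself and everything before it
decode = causal_mask(1, 4)   # the single new query sees the whole cache
chunked = causal_mask(2, 4)  # two new queries appended after two cached keys

for name, mask in (("prefill", prefill), ("decode", decode), ("chunked", chunked)):
    print(name, mask)
```

In the chunked case the first new query sits at position 2 of 4, so only the final key is hidden from it, while the second new query sees all four keys.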

Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)`
preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on
the inputs' device.

## Dispatch Behavior

`kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On
`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op
(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` ->
`forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden
path. When fused attention kernels land, they are prepended to the priority list and the native
op becomes the fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is
a separate dispatch chain and is unaffected.
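
The prepend-and-fall-back behavior described above can be sketched with a toy priority map. Everything here is hypothetical (`PRIORITY`, `REGISTERED`, `resolve`, and the fused backend name are made up for illustration) and is not the `rl_engine` registry API; only `PYTORCH_NATIVE_ATTENTION` is a name taken from this page.

```python
from typing import Dict, List, Set

# Hypothetical per-device priority lists: first registered entry wins.
PRIORITY: Dict[str, List[str]] = {
    "cuda": ["PYTORCH_NATIVE_ATTENTION"],
    "cpu": ["PYTORCH_NATIVE_ATTENTION"],
}
REGISTERED: Set[str] = {"PYTORCH_NATIVE_ATTENTION"}


def resolve(device: str) -> str:
    """Return the first backend in the device's priority list that is registered."""
    for name in PRIORITY[device]:
        if name in REGISTERED:
            return name
    raise LookupError(f"no attention backend for {device!r}")


before = resolve("cuda")

# A fused kernel lands for CUDA: prepend it so it wins; native stays as the fallback.
PRIORITY["cuda"].insert(0, "HYPOTHETICAL_FUSED_ATTENTION")
REGISTERED.add("HYPOTHETICAL_FUSED_ATTENTION")

after = resolve("cuda")
cpu_after = resolve("cpu")
print(before, after, cpu_after)
```

Devices without a fused entry keep resolving to the native op, so the reference remains reachable everywhere.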

## Accuracy

Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled):

```python
import math
import torch

# q, k, v, causal, scale, key_padding_mask are the op's arguments.
B, Hq, Sq, D = q.shape
Hkv, Skv = k.shape[1], k.shape[2]
qf, kf, vf = q.float(), k.float(), v.float()
if Hkv != Hq:                                   # GQA: replicate KV, 32 Q / 8 KV, g = 4
    g = Hq // Hkv
    kf = kf.repeat_interleave(g, dim=1)
    vf = vf.repeat_interleave(g, dim=1)
scale = scale if scale is not None else 1.0 / math.sqrt(D)   # D=128 → 1/√128
scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale      # [B, Hq, Sq, Skv]
if causal:                                      # offset covers prefill + decode
    ones = torch.ones(Sq, Skv, dtype=torch.bool, device=q.device)  # mask on the inputs' device
    m = torch.triu(ones, diagonal=Skv - Sq + 1)
    scores = scores.masked_fill(m, float("-inf"))
if key_padding_mask is not None:                # True = keep ; False columns → -inf
    scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf"))
probs = torch.softmax(scores, dim=-1)           # subtracts per-row max internally
out = torch.matmul(probs, vf)                   # [B, Hq, Sq, D]
```

- **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and
autocast disabled so it is not silently downcast by the caller's ambient context.
- **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions
over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a
tolerance, not bitwise equality.
- **Axis A — batch invariance**: each query row reduces over the keys independently of how many
sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`)
across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax
reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width
(e.g. `Skv=10` vs `6`), and because IEEE 754 floating-point addition is not associative, the
masked result matches the valid-only result only up to a small tolerance (`atol=2e-6`), not bitwise.
- **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction`
row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak):

| dtype | max_abs / peak | threshold (rel-peak) |
| --- | --- | --- |
| bfloat16 | ~0.56 % | 3 % |
| float16 | ~0.07 % | 0.5 % |
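
A minimal sketch of the rel-peak metric in the table, assuming "peak" means the largest absolute value of the fp32 reference (the WS1 contract may define it differently). `rel_peak_error` is a hypothetical helper, and the numbers are synthetic, not measurements.

```python
def rel_peak_error(candidate, reference):
    """max |candidate - reference| divided by max |reference| (the assumed 'peak')."""
    max_abs = max(abs(c - r) for c, r in zip(candidate, reference))
    peak = max(abs(r) for r in reference)
    return max_abs / peak


reference = [0.25, -1.0, 0.5, 2.0]    # stand-in for the fp32 golden output
candidate = [0.25, -1.0, 0.51, 1.98]  # stand-in for a low-precision output

err = rel_peak_error(candidate, reference)
passes_bf16 = err <= 0.03    # 3 % threshold from the table
passes_fp16 = err <= 0.005   # 0.5 % threshold from the table
print(err, passes_bf16, passes_fp16)
```

Normalizing by the peak instead of elementwise keeps near-zero outputs from inflating the relative error.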

## Performance Notes

Reference operator — no fused kernel or benchmark yet. Downstream fused attention kernels carry
their own benchmarks and are measured against this reference for correctness. At the LARGE
Qwen3-8B load point (`B=8`, `Sq=Skv=4096`, `Hq=32`) the fp32 scores tensor alone is ~17 GB and the
naive path peaks at ~3× that, so the LARGE smoke test is GPU-only and skips without enough
memory.
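
The ~17 GB figure follows from the scores shape `[B, Hq, Sq, Skv]` at 4 bytes per element. The arithmetic below assumes full prefill (`Sq == Skv == 4096`), which is the setting that reproduces the number quoted above.

```python
B, Hq, Sq, Skv = 8, 32, 4096, 4096  # Sq == Skv assumes full prefill
BYTES_PER_FP32 = 4

scores_bytes = B * Hq * Sq * Skv * BYTES_PER_FP32
scores_gb = scores_bytes / 1e9      # decimal gigabytes
scores_gib = scores_bytes / 2**30   # binary gibibytes

print(f"{scores_bytes} bytes = {scores_gb:.1f} GB = {scores_gib:.0f} GiB")
```

The probabilities tensor has the same shape as the scores, which is one reason the naive path needs a multiple of this figure at peak.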

## Tests

```bash
python -m pytest tests/test_attention.py -v
```

Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile
autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard,
scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch
invariance (slice + chunked, bitwise; padding is near-equality only, see Accuracy), input purity,
gradient flow, registry dispatch, and a GPU-only LARGE Qwen3-8B real-shape smoke test.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/attention/standard_attn.py`
- `rl_engine/kernels/registry.py`
- `tests/test_attention.py`

## Known Limitations

- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work).
- `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise).
- The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking,
so the LARGE load point is memory-heavy and GPU-only.
- Covers softmax attention only; QK-Norm and RoPE are applied before the call.
161 changes: 161 additions & 0 deletions rl_engine/kernels/ops/pytorch/attention/kv_cache.py
@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

from typing import Optional

import torch

from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp


class NativeKVCacheAttnOp:
    """
    Pure PyTorch native KV-cache attention reference (ISSUE #108 WS1).

    Decode/incremental attention: the past keys/values live in a cache and the
    step's new keys/values are appended before attending::

        k_full = cat([k_cache, k_new], dim=2)   # along the seq axis
        v_full = cat([v_cache, v_new], dim=2)
        out = NativeAttentionOp().forward_fp32(q, k_full, v_full, ...)

    The whole point is that it delegates to the *same* ``NativeAttentionOp`` used
    for full-sequence (prefill) attention: prefill and decode therefore share one
    reduction path, which is what makes rollout (decode) numerically consistent
    with training (prefill). Re-implementing the softmax here would defeat that.

    Qwen3-8B shapes (synthetic tensors, no checkpoint): q ``[B, 32, Sq, 128]``,
    cache/new k,v ``[B, 8, S_past/S_new, 128]`` (GQA group g = 32/8 = 4). Heads
    precede seq in the layout; the GQA KV replication happens inside
    ``NativeAttentionOp``, not here. This is a reduction over the full key length
    Skv = S_past + S_new.

    Alignment assumption: q's ``Sq`` rows are the *last* Sq positions of the full
    sequence -- i.e. the queries for ``k_new`` -- so callers pass ``Sq == S_new``
    (decode: Sq == S_new == 1). The causal offset ``Skv - Sq + 1`` inside
    ``NativeAttentionOp`` then lets each new query see the whole cache plus the
    new tokens up to and including itself, for both decode and chunked prefill.

    Masking conventions (forwarded verbatim to ``NativeAttentionOp``):
      * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1.
      * key_padding_mask ``[B, Skv]`` bool over the *concatenated* length
        (S_past + S_new), True=valid / False=padding.

    Only the attention output is returned; producing an updated cache is a
    caller/runtime concern and is not part of this numerical contract.
    """

    def __init__(self) -> None:
        """No state; the op is a pure function over (q, k_cache, v_cache, k_new, v_new, ...)."""
        self._attn = NativeAttentionOp()

    def __call__(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        k_new: torch.Tensor,
        v_new: torch.Tensor,
        *,
        causal: bool = True,
        scale: Optional[float] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Alias for ``forward`` so the op is callable like a module."""
        return self.forward(
            q,
            k_cache,
            v_cache,
            k_new,
            v_new,
            causal=causal,
            scale=scale,
            key_padding_mask=key_padding_mask,
        )

    def forward(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        k_new: torch.Tensor,
        v_new: torch.Tensor,
        *,
        causal: bool = True,
        scale: Optional[float] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Canonical entry: concat cache+new, then attend in the input dtype.
        Delegates to ``NativeAttentionOp.forward`` (the Axis-B dtype path).
        """
        self._validate_decode_alignment(q, k_new, v_new)
        k_full, v_full = self._concat_kv(k_cache, v_cache, k_new, v_new)
        return self._attn.forward(
            q, k_full, v_full, causal=causal, scale=scale, key_padding_mask=key_padding_mask
        )

    def forward_fp32(
        self,
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        k_new: torch.Tensor,
        v_new: torch.Tensor,
        *,
        causal: bool = True,
        scale: Optional[float] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Ground truth: concat cache+new, then attend in strict fp32.
        Delegates to ``NativeAttentionOp.forward_fp32`` so the fp32 golden path
        is identical to prefill's.
        """
        self._validate_decode_alignment(q, k_new, v_new)
        k_full, v_full = self._concat_kv(k_cache, v_cache, k_new, v_new)
        return self._attn.forward_fp32(
            q, k_full, v_full, causal=causal, scale=scale, key_padding_mask=key_padding_mask
        )

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _validate_decode_alignment(
        q: torch.Tensor,
        k_new: torch.Tensor,
        v_new: torch.Tensor,
    ) -> None:
        """Enforce the contract that ``q`` holds exactly the newly appended
        positions: ``Sq == S_new``.

        q's rows are the queries for ``k_new``, so the causal offset inside
        ``NativeAttentionOp`` (``Skv - Sq + 1``) is only correct when their seq
        lengths match. A mismatched ``q`` would otherwise silently attend with
        the wrong offset and return a wrong-but-finite result. Seq axis is dim=2
        in the ``[B, H, S, D]`` layout.
        """
        sq, s_new_k, s_new_v = q.size(2), k_new.size(2), v_new.size(2)
        if sq != s_new_k or sq != s_new_v:
            raise ValueError(
                "kv_cache attention expects q to hold exactly the new positions "
                f"(Sq == S_new): got Sq={sq}, k_new={s_new_k}, v_new={s_new_v}."
            )

    @staticmethod
    def _concat_kv(
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        k_new: torch.Tensor,
        v_new: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append the step's new K/V to the cache along the seq axis (dim=2).

        Layout is [B, Hkv, S, D], so the sequence axis is dim=2. Pure (no
        in-place writes into the passed-in cache tensors).
        """
        k_full = torch.cat([k_cache, k_new], dim=2)
        v_full = torch.cat([v_cache, v_new], dim=2)
        return k_full, v_full