Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ jobs:
python -m pytest rl_engine/tests/test_dispatch.py -v
PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs

- name: Run Attention Ground-Truth Tests (CPU-safe)
run: |
python -m pytest tests/test_attention.py -v -k "not large and not gpu"

docs:
runs-on: ubuntu-latest
steps:
Expand Down
1 change: 1 addition & 0 deletions docs/.nav.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ nav:
- Operators:
- operators/README.md
- operators/activation.md
- operators/attention.md
- operators/fused-logp.md
- operators/linear-logp.md
- operators/grpo-loss.md
Expand Down
1 change: 1 addition & 0 deletions docs/operators/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Every operator page should include:
## Current Pages

- [SiLU / SwiGLU Activation](activation.md)
- [Standard Attention](attention.md)
- [Fused LogP](fused-logp.md)
- [Fused Linear LogP](linear-logp.md)
- [GRPO Loss](grpo-loss.md)
Expand Down
159 changes: 159 additions & 0 deletions docs/operators/attention.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Standard Softmax Attention

The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a
**WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of
the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and
friends) are validated against.

- **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive
softmax with a **fixed reduction order**, deliberately **not**
`F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order
is unspecified and would break the batch-invariance contract).

This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before*
the call (see the chain), so the `q`, `k` passed in are already normalized and rotated.

```text
q --\
k ----softmax(QKᵀ/√d + mask)·V--> out
v --/
```

## Entry Point
```python
from rl_engine.kernels.registry import kernel_registry

attn = kernel_registry.get_op("attention")

# Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache)
out = attn(q, k, v, causal=True) # [B, 32, Sq, 128]
out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale
out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep
```

The op exposes the WS1 dual-path contract:

- `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy
candidate / dtype-behavior path).
- `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth
golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of
the caller's ambient precision context.

> **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production
> SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under
> `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap.

## Backends

| Backend | Wrapper | Native symbol | Status |
| --- | --- | --- | --- |
| PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. |
| CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. |

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. |
| `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. |
| `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. |
| `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. |
| `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. An explicit value (incl. `0.0`) is used verbatim. |
| `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. |
| output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). |

**GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with
`repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`.

**Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single
expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees
the whole cache).

Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)`
preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on
the inputs' device.

## Dispatch Behavior

`kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On
`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op
(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` ->
`forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden
path. When fused attention kernels land, they are prepended to the priority list and the native
op becomes the fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is
a separate dispatch chain and is unaffected.

## Accuracy

Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled):

```python
qf, kf, vf = q.float(), k.float(), v.float()
if Hkv != Hq: # GQA: replicate KV, 32 Q / 8 KV, r = 4
r = Hq // Hkv
kf = kf.repeat_interleave(r, dim=1)
vf = vf.repeat_interleave(r, dim=1)
scale = scale if scale is not None else 1.0 / math.sqrt(D) # D=128 → 1/√128
scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv]
if causal: # offset covers prefill + decode
m = torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1)
scores = scores.masked_fill(m, float("-inf"))
if key_padding_mask is not None: # True = keep ; False columns → -inf
scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf"))
probs = torch.softmax(scores, dim=-1) # subtracts per-row max internally
out = torch.matmul(probs, vf) # [B, Hq, Sq, D]
```

- **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and
autocast disabled so it is not silently downcast by the caller's ambient context.
- **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions
over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a
tolerance, not bitwise equality.
- **Axis A — batch invariance**: each query row reduces over the keys independently of how many
sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`)
across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax
reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width
(e.g. Skv=10 vs 6), so the masked result only matches the valid-only result up to a small
tolerance (`atol=2e-6`), not bitwise, in IEEE 754.
- **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction`
row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak):

| dtype | max_abs / peak | threshold (rel-peak) |
| --- | --- | --- |
| bfloat16 | ~0.56 % | 3 % |
| float16 | ~0.07 % | 0.5 % |

## Performance Notes

Reference operator — no fused kernel or benchmark yet. Downstream fused attention kernels carry
their own benchmarks and are measured against this reference for correctness. At the LARGE
Qwen3-8B load point (`B=8`, `Skv=4096`, `Hq=32`) the fp32 scores tensor alone is ~17 GB and the
naive path peaks at ~3× that, so the LARGE smoke test is GPU-only and skips without enough
memory.

## Tests

```bash
python -m pytest tests/test_attention.py -v
```

Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile
autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard,
scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch
invariance (slice + chunked, bitwise; padding is near-equality only, see below), input purity,
gradient flow, registry dispatch, and a
GPU-only LARGE Qwen3-8B real-shape smoke test.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/attention/standard_attn.py`
- `rl_engine/kernels/registry.py`
- `tests/test_attention.py`

## Known Limitations

- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work).
- `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise).
- The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking,
so the LARGE load point is memory-heavy and GPU-only.
- Covers softmax attention only; QK-Norm and RoPE are applied before the call.
191 changes: 191 additions & 0 deletions rl_engine/kernels/ops/pytorch/attention/standard_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

import math
from contextlib import contextmanager, nullcontext
from typing import Optional

import torch


class NativeAttentionOp:
"""
Pure PyTorch native standard-softmax attention reference.
out = softmax(Q Kᵀ * scale + masks) @ V

Hand-written naive softmax -- deliberately NOT
``F.scaled_dot_product_attention`` / flash / mem-efficient attention, whose
reduction order is unspecified and would break the batch-invariance (Axis-A)
contract. This op defines the *correct answer* the fused kernels align to.

Qwen3-8B shapes: q ``[B, 32, Sq, 128]``, k/v ``[B, 8, Skv, 128]`` (GQA group
g = 32/8 = 4), scale = 1/sqrt(head_dim) = 1/sqrt(128). Heads precede seq in
the layout. This op covers ONLY the softmax attention; QK-Norm and RoPE are
applied *before* the call (see the chain test) -- the q,k passed in are
already normalized and rotated.

This is a reduction over the key dimension (Skv): the low-precision
``forward`` path accumulates in the input dtype and therefore drifts from
the fp32 ``forward_fp32`` ground truth, so Axis-B accuracy uses a tolerance
(``torch.allclose``), not bitwise equality. Axis-A batch invariance still
holds bitwise within a single dtype (each query row reduces over the keys
independently of how many sequences share the batch).

Masking conventions:
* causal=True -> upper-triangular -inf at diagonal Skv-Sq+1, valid for
both prefill (Sq==Skv) and decode (Sq<Skv).
* key_padding_mask ``[B, Skv]`` bool, True=valid / False=padding -> padded
key columns set to -inf (matches reference_ops.py: True=keep, False=mask).
"""

def __init__(self) -> None:
"""No state; the op is a pure function over (q, k, v, ...)."""

def __call__(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
causal: bool = True,
scale: Optional[float] = None,
key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Alias for ``forward`` so the op is callable like a module."""
return self.forward(q, k, v, causal=causal, scale=scale, key_padding_mask=key_padding_mask)

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
causal: bool = True,
scale: Optional[float] = None,
key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Canonical entry: attend in the input dtype, output the input dtype.
This is the dtype-behavior path used as the Axis-B accuracy candidate.
"""
return self._attention(
q,
k,
v,
causal=causal,
scale=scale,
key_padding_mask=key_padding_mask,
compute_dtype=q.dtype,
output_dtype=q.dtype,
)

def forward_fp32(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
causal: bool = True,
scale: Optional[float] = None,
key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Ground truth: upcast to fp32, accumulate in fp32, force fp32 output.

The whole score->softmax->value path is wrapped to disable autocast and
TF32 so this stays a true fp32 reference regardless of the caller's
ambient precision context.
"""
return self._attention(
q,
k,
v,
causal=causal,
scale=scale,
key_padding_mask=key_padding_mask,
compute_dtype=torch.float32,
output_dtype=torch.float32,
strict_fp32=True,
)

# ------------------------------------------------------------------ #
# Helpers
# ------------------------------------------------------------------ #
@staticmethod
def _attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
causal: bool,
scale: Optional[float],
key_padding_mask: Optional[torch.Tensor],
compute_dtype: torch.dtype,
output_dtype: torch.dtype,
strict_fp32: bool = False,
) -> torch.Tensor:
"""Core softmax attention: cast to ``compute_dtype``, score, mask, softmax,
weighted-sum over V, cast out. ``strict_fp32`` disables autocast/TF32 so the
fp32 reference is not silently downcast by the caller's ambient context.
"""
Hq, Sq, D = q.shape[1], q.shape[2], q.shape[3]
Hkv, Skv = k.shape[1], k.shape[2]
if Hq % Hkv != 0:
raise ValueError(f"Hq={Hq} not divisible by Hkv={Hkv} (GQA group)")

ctx = NativeAttentionOp._strict_fp32_math(q.device.type) if strict_fp32 else nullcontext()
with ctx:
qf = q.to(compute_dtype)
kf = k.to(compute_dtype)
vf = v.to(compute_dtype)

# GQA: replicate each KV head g=Hq//Hkv times (Qwen3: 32/8 -> 4).
# repeat_interleave (not repeat) keeps each KV head's copies adjacent
# so query head h maps to KV head h // g.
if Hkv != Hq:
r = Hq // Hkv
kf = kf.repeat_interleave(r, dim=1)
vf = vf.repeat_interleave(r, dim=1)

# scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept.
scale = scale if scale is not None else (1.0 / math.sqrt(D))
scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv]
Comment thread
Flink-ddd marked this conversation as resolved.

# Causal: offset Skv-Sq+1 covers prefill (Sq==Skv) and decode (Sq<Skv).
if causal:
causal_mask = torch.triu(
torch.ones(Sq, Skv, dtype=torch.bool, device=q.device),
diagonal=Skv - Sq + 1,
)
scores = scores.masked_fill(causal_mask, float("-inf"))

# key_padding_mask [B, Skv]: True=valid; False columns -> -inf.
if key_padding_mask is not None:
pad = ~key_padding_mask
scores = scores.masked_fill(pad[:, None, None, :], float("-inf"))

probs = torch.softmax(scores, dim=-1) # subtracts row max internally
if key_padding_mask is not None:
# A query whose every valid key is padded out has an all -inf row;
# softmax would emit NaN. Define such fully-masked rows as 0 (no key
# contributes), keeping outputs/grads finite. Only key_padding_mask can
# fully mask a row -- causal alone always leaves a query its own key --
# so this guard lives in the padding branch and the no-pad path is
# untouched. Row-independent, so Axis-A batch invariance still holds.
all_masked = ~torch.isfinite(scores).any(dim=-1, keepdim=True)
probs = torch.where(all_masked, torch.zeros_like(probs), probs)
out = torch.matmul(probs, vf) # [B, Hq, Sq, D]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
return out.to(output_dtype)

@staticmethod
@contextmanager
def _strict_fp32_math(device_type: str):
"""Disable autocast and TF32 for a true fp32 path, restoring state after."""
prev_tf32 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
try:
with torch.autocast(device_type=device_type, enabled=False):
yield
finally:
torch.backends.cuda.matmul.allow_tf32 = prev_tf32
Loading
Loading