diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 755f426..92cd043 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,6 +68,14 @@ jobs: python -m pytest rl_engine/tests/test_dispatch.py -v PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs + - name: Run Attention Ground-Truth Tests (CPU-safe) + run: | + python -m pytest tests/test_attention.py -v -k "not large and not gpu" + + - name: Run KV-Cache Attention Ground-Truth Tests (CPU-safe) + run: | + python -m pytest tests/test_kv_cache_attention.py -v -k "not large and not gpu" + docs: runs-on: ubuntu-latest steps: diff --git a/docs/.nav.yml b/docs/.nav.yml index 9321713..1274acb 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -12,6 +12,7 @@ nav: - Operators: - operators/README.md - operators/activation.md + - operators/attention.md - operators/fused-logp.md - operators/linear-logp.md - operators/grpo-loss.md diff --git a/docs/operators/README.md b/docs/operators/README.md index eb2f5de..aa73396 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -19,6 +19,7 @@ Every operator page should include: ## Current Pages - [SiLU / SwiGLU Activation](activation.md) +- [Standard Attention](attention.md) - [Fused LogP](fused-logp.md) - [Fused Linear LogP](linear-logp.md) - [GRPO Loss](grpo-loss.md) diff --git a/docs/operators/attention.md b/docs/operators/attention.md new file mode 100644 index 0000000..e3cb5f9 --- /dev/null +++ b/docs/operators/attention.md @@ -0,0 +1,159 @@ +# Standard Softmax Attention + +The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a +**WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of +the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and +friends) are validated against. 
+ +- **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive + softmax with a **fixed reduction order**, deliberately **not** + `F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order + is unspecified and would break the batch-invariance contract). + +This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before* +the call (see the chain test), so the `q`, `k` passed in are already normalized and rotated. + +```text +q --\ +k ----softmax(QKᵀ/√d + mask)·V--> out +v --/ +``` + +## Entry Point +```python +from rl_engine.kernels.registry import kernel_registry + +attn = kernel_registry.get_op("attention") + +# Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache) +out = attn(q, k, v, causal=True) # [B, 32, Sq, 128] +out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale +out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep +``` + +The op exposes the WS1 dual-path contract: + +- `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy + candidate / dtype-behavior path). +- `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth + golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of + the caller's ambient precision context. + +> **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production +> SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under +> `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap. + +## Backends + +| Backend | Wrapper | Native symbol | Status | +| --- | --- | --- | --- | +| PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. | +| CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. 
| + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. | +| `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. | +| `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. | +| `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. | +| `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. An explicit value (incl. `0.0`) is used verbatim. | +| `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. | +| output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). | + +**GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with +`repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`. + +**Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single +expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees +the whole cache). + +Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)` +preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on +the inputs' device. + +## Dispatch Behavior + +`kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On +`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op +(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` -> +`forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden +path. When fused attention kernels land, they are prepended to the priority list and the native +op becomes the fallback. 
The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is +a separate dispatch chain and is unaffected. + +## Accuracy + +Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled): + +```python +qf, kf, vf = q.float(), k.float(), v.float() +if Hkv != Hq: # GQA: replicate KV, 32 Q / 8 KV, r = 4 + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) +scale = scale if scale is not None else 1.0 / math.sqrt(D) # D=128 → 1/√128 +scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] +if causal: # offset covers prefill + decode + m = torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1) + scores = scores.masked_fill(m, float("-inf")) +if key_padding_mask is not None: # True = keep ; False columns → -inf + scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) +probs = torch.softmax(scores, dim=-1) # subtracts per-row max internally +out = torch.matmul(probs, vf) # [B, Hq, Sq, D] +``` + +- **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and + autocast disabled so it is not silently downcast by the caller's ambient context. +- **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions + over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a + tolerance, not bitwise equality. +- **Axis A — batch invariance**: each query row reduces over the keys independently of how many + sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`) + across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax + reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width + (e.g. Skv=10 vs 6), so the masked result only matches the valid-only result up to a small + tolerance (`atol=2e-6`), not bitwise, in IEEE 754. 
+- **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` + row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): + + | dtype | max_abs / peak | threshold (rel-peak) | + | --- | --- | --- | + | bfloat16 | ~0.56 % | 3 % | + | float16 | ~0.07 % | 0.5 % | + +## Performance Notes + +Reference operator — no fused kernel or benchmark yet. Downstream fused attention kernels carry +their own benchmarks and are measured against this reference for correctness. At the LARGE +Qwen3-8B load point (`B=8`, `Skv=4096`, `Hq=32`) the fp32 scores tensor alone is ~17 GB and the +naive path peaks at ~3× that, so the LARGE smoke test is GPU-only and skips without enough +memory. + +## Tests + +```bash +python -m pytest tests/test_attention.py -v +``` + +Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile +autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard, +scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch +invariance (slice + chunked, bitwise; padding is near-equality only, see Accuracy), input purity, +gradient flow, registry dispatch, and a +GPU-only LARGE Qwen3-8B real-shape smoke test. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/attention/standard_attn.py` +- `rl_engine/kernels/registry.py` +- `tests/test_attention.py` + +## Known Limitations + +- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). +- `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise). +- The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking, + so the LARGE load point is memory-heavy and GPU-only. +- Covers softmax attention only; QK-Norm and RoPE are applied before the call. 
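The causal offset `Skv - Sq + 1` described under Tensor Contract can be checked without PyTorch. The following is a minimal sketch, not part of `rl_engine`: `causal_mask` and `visible_keys` are hypothetical helpers that reproduce the `True` (masked) positions of `torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1)` with plain lists, so prefill, decode, and chunked prefill can be compared side by side.

```python
def causal_mask(sq: int, skv: int) -> list[list[bool]]:
    """Return an [sq][skv] grid where True marks a key the query must NOT see.

    Hypothetical helper for illustration only. It mirrors
    triu(ones(sq, skv), diagonal=skv - sq + 1): entry (i, j) is kept by triu
    (and therefore masked) exactly when j - i >= skv - sq + 1.
    """
    offset = skv - sq + 1
    return [[(j - i) >= offset for j in range(skv)] for i in range(sq)]


def visible_keys(sq: int, skv: int) -> list[list[int]]:
    """For each query row, list the key indices that stay unmasked."""
    return [
        [j for j, masked in enumerate(row) if not masked]
        for row in causal_mask(sq, skv)
    ]


if __name__ == "__main__":
    print(visible_keys(3, 3))  # prefill: query i sees keys 0..i
    print(visible_keys(1, 4))  # decode: the single query sees the whole cache
    print(visible_keys(2, 4))  # chunked prefill: 2 new queries after 2 cached keys
```

Because the queries are anchored to the end of the key sequence, one expression covers all three cases. With `Sq == 1` no key is masked, which is why causal decode equals the non-causal result.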
diff --git a/rl_engine/kernels/ops/pytorch/attention/kv_cache.py b/rl_engine/kernels/ops/pytorch/attention/kv_cache.py new file mode 100644 index 0000000..a53931e --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/attention/kv_cache.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +from typing import Optional + +import torch + +from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp + + +class NativeKVCacheAttnOp: + """ + Pure PyTorch native KV-cache attention reference (ISSUE #108 WS1). + + Decode/incremental attention: the past keys/values live in a cache and the + step's new keys/values are appended before attending:: + + k_full = cat([k_cache, k_new], dim=2) # along the seq axis + v_full = cat([v_cache, v_new], dim=2) + out = NativeAttentionOp().forward_fp32(q, k_full, v_full, ...) + + The whole point is that it delegates to the *same* ``NativeAttentionOp`` used + for full-sequence (prefill) attention: prefill and decode therefore share one + reduction path, which is what makes rollout (decode) numerically consistent + with training (prefill). Re-implementing the softmax here would defeat that. + + Qwen3-8B shapes (synthetic tensors, no checkpoint): q ``[B, 32, Sq, 128]``, + cache/new k,v ``[B, 8, S_past/S_new, 128]`` (GQA group g = 32/8 = 4). Heads + precede seq in the layout; the GQA KV replication happens inside + ``NativeAttentionOp``, not here. This is a reduction over the full key length + Skv = S_past + S_new. + + Alignment assumption: q's ``Sq`` rows are the *last* Sq positions of the full + sequence -- i.e. the queries for ``k_new`` -- so callers pass ``Sq == S_new`` + (decode: Sq == S_new == 1). The causal offset ``Skv - Sq + 1`` inside + ``NativeAttentionOp`` then lets each new query see the whole cache plus the + new tokens up to and including itself, for both decode and chunked prefill. 
+ + Masking conventions (forwarded verbatim to ``NativeAttentionOp``): + * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1. + * key_padding_mask ``[B, Skv]`` bool over the *concatenated* length + (S_past + S_new), True=valid / False=padding. + + Only the attention output is returned; producing an updated cache is a + caller/runtime concern and is not part of this numerical contract. + """ + + def __init__(self) -> None: + """No state; the op is a pure function over (q, k_cache, v_cache, k_new, v_new, ...).""" + self._attn = NativeAttentionOp() + + def __call__( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Alias for ``forward`` so the op is callable like a module.""" + return self.forward( + q, + k_cache, + v_cache, + k_new, + v_new, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + ) + + def forward( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Canonical entry: concat cache+new, then attend in the input dtype. + Delegates to ``NativeAttentionOp.forward`` (the Axis-B dtype path). 
+ """ + self._validate_decode_alignment(q, k_new, v_new) + k_full, v_full = self._concat_kv(k_cache, v_cache, k_new, v_new) + return self._attn.forward( + q, k_full, v_full, causal=causal, scale=scale, key_padding_mask=key_padding_mask + ) + + def forward_fp32( + self, + q: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Ground truth: concat cache+new, then attend in strict fp32. + Delegates to ``NativeAttentionOp.forward_fp32`` so the fp32 golden path + is identical to prefill's. + """ + self._validate_decode_alignment(q, k_new, v_new) + k_full, v_full = self._concat_kv(k_cache, v_cache, k_new, v_new) + return self._attn.forward_fp32( + q, k_full, v_full, causal=causal, scale=scale, key_padding_mask=key_padding_mask + ) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _validate_decode_alignment( + q: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + ) -> None: + """Enforce the contract that ``q`` holds exactly the newly appended + positions: ``Sq == S_new``. + + q's rows are the queries for ``k_new``, so the causal offset inside + ``NativeAttentionOp`` (``Skv - Sq + 1``) is only correct when their seq + lengths match. A mismatched ``q`` would otherwise silently attend with + the wrong offset and return a wrong-but-finite result. Seq axis is dim=2 + in the ``[B, H, S, D]`` layout. + """ + sq, s_new_k, s_new_v = q.size(2), k_new.size(2), v_new.size(2) + if sq != s_new_k or sq != s_new_v: + raise ValueError( + "kv_cache attention expects q to hold exactly the new positions " + f"(Sq == S_new): got Sq={sq}, k_new={s_new_k}, v_new={s_new_v}." 
+ ) + + @staticmethod + def _concat_kv( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_new: torch.Tensor, + v_new: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Append the step's new K/V to the cache along the seq axis (dim=2). + + Layout is [B, Hkv, S, D], so the sequence axis is dim=2. Pure (no + in-place writes into the passed-in cache tensors). + """ + k_full = torch.cat([k_cache, k_new], dim=2) + v_full = torch.cat([v_cache, v_new], dim=2) + return k_full, v_full diff --git a/rl_engine/kernels/ops/pytorch/attention/standard_attn.py b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py new file mode 100644 index 0000000..804efb6 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import math +from contextlib import contextmanager, nullcontext +from typing import Optional + +import torch + + +class NativeAttentionOp: + """ + Pure PyTorch native standard-softmax attention reference. + out = softmax(Q Kᵀ * scale + masks) @ V + + Hand-written naive softmax -- deliberately NOT + ``F.scaled_dot_product_attention`` / flash / mem-efficient attention, whose + reduction order is unspecified and would break the batch-invariance (Axis-A) + contract. This op defines the *correct answer* the fused kernels align to. + + Qwen3-8B shapes: q ``[B, 32, Sq, 128]``, k/v ``[B, 8, Skv, 128]`` (GQA group + g = 32/8 = 4), scale = 1/sqrt(head_dim) = 1/sqrt(128). Heads precede seq in + the layout. This op covers ONLY the softmax attention; QK-Norm and RoPE are + applied *before* the call (see the chain test) -- the q,k passed in are + already normalized and rotated. 
+ + This is a reduction over the key dimension (Skv): the low-precision + ``forward`` path accumulates in the input dtype and therefore drifts from + the fp32 ``forward_fp32`` ground truth, so Axis-B accuracy uses a tolerance + (``torch.allclose``), not bitwise equality. Axis-A batch invariance still + holds bitwise within a single dtype (each query row reduces over the keys + independently of how many sequences share the batch). + + Masking conventions: + * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1, valid for + both prefill (Sq==Skv) and decode (Sq < Skv). + * key_padding_mask ``[B, Skv]`` bool, True=valid / False=padding -> padded + key columns set to -inf (matches reference_ops.py: True=keep, False=mask). + """ + + def __init__(self) -> None: + """No state; the op is a pure function over (q, k, v, ...).""" + + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Alias for ``forward`` so the op is callable like a module.""" + return self.forward(q, k, v, causal=causal, scale=scale, key_padding_mask=key_padding_mask) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Canonical entry: attend in the input dtype, output the input dtype. + This is the dtype-behavior path used as the Axis-B accuracy candidate. + """ + return self._attention( + q, + k, + v, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + compute_dtype=q.dtype, + output_dtype=q.dtype, + ) + + def forward_fp32( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Ground truth: upcast to fp32, accumulate in fp32, force fp32 output. 
+ + The whole score->softmax->value path is wrapped to disable autocast and + TF32 so this stays a true fp32 reference regardless of the caller's + ambient precision context. + """ + return self._attention( + q, + k, + v, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + compute_dtype=torch.float32, + output_dtype=torch.float32, + strict_fp32=True, + ) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool, + scale: Optional[float], + key_padding_mask: Optional[torch.Tensor], + compute_dtype: torch.dtype, + output_dtype: torch.dtype, + strict_fp32: bool = False, + ) -> torch.Tensor: + """Core softmax attention: cast to ``compute_dtype``, score, mask, softmax, + weighted-sum over V, cast out. ``strict_fp32`` disables autocast/TF32 so the + fp32 reference is not silently downcast by the caller's ambient context. + """ + Hq, Sq, D = q.shape[1], q.shape[2], q.shape[3] + Hkv, Skv = k.shape[1], k.shape[2] + if Hq % Hkv != 0: + raise ValueError(f"Hq={Hq} not divisible by Hkv={Hkv} (GQA group)") + + ctx = NativeAttentionOp._strict_fp32_math(q.device.type) if strict_fp32 else nullcontext() + with ctx: + qf = q.to(compute_dtype) + kf = k.to(compute_dtype) + vf = v.to(compute_dtype) + + # GQA: replicate each KV head g=Hq//Hkv times (Qwen3: 32/8 -> 4). + # repeat_interleave (not repeat) keeps each KV head's copies adjacent + # so query head h maps to KV head h // g. + if Hkv != Hq: + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) + + # scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept. 
+ scale = scale if scale is not None else (1.0 / math.sqrt(D)) + scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] + + # Causal: offset Skv-Sq+1 covers prefill (Sq==Skv) and decode (Sq < Skv). + if causal: + m = torch.triu( + torch.ones(Sq, Skv, dtype=torch.bool, device=q.device), diagonal=Skv - Sq + 1 + ) + scores = scores.masked_fill(m, float("-inf")) + + # key_padding_mask: True = valid / keep; False = padding -> -inf. + if key_padding_mask is not None: + pad = ~key_padding_mask + scores = scores.masked_fill(pad[:, None, None, :], float("-inf")) + + probs = torch.softmax(scores, dim=-1) # subtracts row max internally + if key_padding_mask is not None: + # A query whose every valid key is padded out has an all -inf row; + # softmax would emit NaN. Define such fully-masked rows as 0 (no key + # contributes), keeping outputs/grads finite. Only key_padding_mask can + # fully mask a row -- causal alone always leaves a query its own key -- + # so this guard lives in the padding branch and the no-pad path is + # untouched. Row-independent, so Axis-A batch invariance still holds. + all_masked = ~torch.isfinite(scores).any(dim=-1, keepdim=True) + probs = torch.where(all_masked, torch.zeros_like(probs), probs) + out = torch.matmul(probs, vf) # [B, Hq, Sq, D] + return out.to(output_dtype) + + @staticmethod + @contextmanager + def _strict_fp32_math(device_type: str): + """Disable autocast and TF32 for a true fp32 path, restoring state after.""" + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + with torch.autocast(device_type=device_type, enabled=False): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 5e63d84..be3f931 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -59,6 +59,16 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): PYTORCH_NATIVE_SILU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSiLUOp" PYTORCH_NATIVE_SWIGLU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSwiGLUOp" + # WS1 pure-PyTorch ground-truth attention reference (hand-written fp32 softmax). 
+ # Distinct from PYTORCH_ATTN above, which is the production SDPA fallback. + PYTORCH_NATIVE_ATTENTION = ( + "rl_engine.kernels.ops.pytorch.attention.standard_attn.NativeAttentionOp" + ) + # WS1 pure-PyTorch ground-truth KV-cache (decode/incremental) attention + # reference; concats cache+new then reuses the standard attention reduction. + PYTORCH_NATIVE_KV_CACHE_ATTN = ( + "rl_engine.kernels.ops.pytorch.attention.kv_cache.NativeKVCacheAttnOp" + ) # WS1 pure-PyTorch ground-truth linear ops PYTORCH_NATIVE_LM_HEAD = "rl_engine.kernels.ops.pytorch.linear.lm_head.NativeLMHeadOp" # WS1 pure-PyTorch ground-truth embedding ops @@ -96,6 +106,8 @@ def __init__(self): OpBackend.PYTORCH_NATIVE, ], "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_ATTN], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], + "kv_cache_attention": [OpBackend.PYTORCH_NATIVE_KV_CACHE_ATTN], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], @@ -113,6 +125,8 @@ def __init__(self): OpBackend.PYTORCH_ATTN, OpBackend.TRITON_GENERIC, ], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], + "kv_cache_attention": [OpBackend.PYTORCH_NATIVE_KV_CACHE_ATTN], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], @@ -125,6 +139,8 @@ def __init__(self): "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_ATTN], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], + "kv_cache_attention": [OpBackend.PYTORCH_NATIVE_KV_CACHE_ATTN], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 
100644 index 0000000..b9eb598 --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,448 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Tests for NativeAttentionOp (ISSUE #108 WS1 ground-truth baseline). + +Standard-softmax attention: out = softmax(Q Kᵀ * scale + masks) @ V, hand-written +(NOT F.scaled_dot_product_attention) so the reduction order is fixed. Like +lm_head this is a *reduction* (over the key dim Skv), so: + + * Axis-B (accuracy): the low-precision ``forward`` path accumulates in the + input dtype and drifts from the fp32 ``forward_fp32`` ground truth. It is + checked with a tolerance relative to the output peak magnitude, not bitwise + -- attention outputs are convex combinations of V rows, so many entries sit + near zero while the accumulated error tracks the reduction length. + * Axis-A (batch invariance): bitwise within a single dtype, but only once the + CPU reduction order is pinned. Multi-threaded CPU GEMM splits the matmul + reduction by the batch dimension, which silently breaks bitwise batch + invariance. ``_single_thread`` fixes the reduction order; it is the local + stand-in for the planned testing/determinism.py::deterministic_context. + +This op covers ONLY the softmax attention; QK-Norm and RoPE are applied before +the call (see the chain test) -- the q,k here are plain synthetic tensors. +""" + +import contextlib +import math + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp +from rl_engine.kernels.registry import kernel_registry + +# Qwen3-8B attention dims (synthetic tensors, no checkpoint). Unlike embedding / +# lm_head (whose multi-GB weight forces shrinking), attention's cost is the +# scores tensor [B, Hq, Sq, Skv], so the *real* head dims are cheap at a SMALL +# (batch, seq) load point and kept real here. Only LARGE (8, 4096) is GPU-only. 
+_N_HEADS = 32 # Q heads +_N_KV = 8 # KV heads; GQA group g = 32 / 8 = 4 +_HEAD_DIM = 128 # 32 * 128 == 4096 == hidden + +# Axis-B: max abs error as a fraction of the output peak magnitude. Calibrated +# from measured SMALL drift (bf16 ~1% of peak, fp16 ~0.1%) with headroom. +_DTYPE_REL_PEAK = {torch.bfloat16: 3.0e-2, torch.float16: 5.0e-3} + + +def _cpu_fp16_matmul_supported() -> bool: + """Probe whether this CPU backend implements float16 matmul.""" + try: + _ = torch.randn(2, 2, dtype=torch.float16) @ torch.randn(2, 2, dtype=torch.float16) + return True + except RuntimeError: + return False + + +# CPU half-precision matmul is backend/ISA-dependent and may be unimplemented on +# some runners -- gate the fp16 axis so a missing kernel skips rather than fails. +_FP16_IF_CPU_MATMUL_SUPPORTED = pytest.param( + torch.float16, + marks=pytest.mark.skipif( + not _cpu_fp16_matmul_supported(), + reason="CPU float16 matmul unsupported on this backend", + ), +) +_DTYPES_AXIS_B = (torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) +_DTYPES_AXIS_A = (torch.float32, torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) + + +@contextlib.contextmanager +def _single_thread(): + """Pin CPU GEMM to one thread so the matmul reduction order is batch-independent.""" + prev = torch.get_num_threads() + torch.set_num_threads(1) + try: + yield + finally: + torch.set_num_threads(prev) + + +# Shared helpers -- fixed-seed Generator for determinism / reproducibility. 
+def _qkv(batch, sq, skv, *, seed, dtype=torch.float32, n_heads=_N_HEADS, n_kv=_N_KV, d=_HEAD_DIM): + """Fixed-seed random q [B,Hq,Sq,D], k/v [B,Hkv,Skv,D] for reproducibility.""" + gen = torch.Generator().manual_seed(seed) + q = torch.randn(batch, n_heads, sq, d, generator=gen, dtype=dtype) + k = torch.randn(batch, n_kv, skv, d, generator=gen, dtype=dtype) + v = torch.randn(batch, n_kv, skv, d, generator=gen, dtype=dtype) + return q, k, v + + +def _ref_softmax_attn(q, k, v, *, causal, scale=None, key_padding_mask=None): + """Independent naive-softmax reference mirroring the contract (GQA, masks). + + Dtype-preserving: pass fp32 tensors for a bitwise fp32 check, or .double() + tensors for a TF32-immune high-precision reference. + """ + qf, kf, vf = q, k, v + Hq, Sq, D = qf.shape[1], qf.shape[2], qf.shape[3] + Hkv, Skv = kf.shape[1], kf.shape[2] + if Hkv != Hq: + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) + s = torch.matmul(qf, kf.transpose(-1, -2)) * ( + scale if scale is not None else 1.0 / math.sqrt(D) + ) + if causal: + m = torch.triu( + torch.ones(Sq, Skv, dtype=torch.bool, device=qf.device), diagonal=Skv - Sq + 1 + ) + s = s.masked_fill(m, float("-inf")) + if key_padding_mask is not None: + s = s.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) + return torch.softmax(s, dim=-1) @ vf + + +# --------------------------------------------------------------------------- # +# fp32 ground-truth correctness +# --------------------------------------------------------------------------- # +# forward_fp32 == the independent naive fp32 reference, bitwise. Both fix the +# same reduction order, so this validates the op's wiring (transpose dims, scale, +# masks) exactly. TF32 is pinned off so the fp32 forward path matches too. 
+def test_forward_fp32_matches_independent_reference(): + """forward_fp32 (and the fp32 forward path, TF32 off) is bitwise-equal to a + naive fp32 reference.""" + q, k, v = _qkv(2, 16, 16, seed=1) # Qwen3 32/8/128, SMALL prefill + ref = _ref_softmax_attn(q, k, v, causal=True) + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + assert torch.equal(NativeAttentionOp().forward_fp32(q, k, v, causal=True), ref) + assert torch.equal(NativeAttentionOp().forward(q, k, v, causal=True), ref) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + +def test_forward_fp32_ignores_ambient_autocast_and_restores_tf32(): + """forward_fp32 is a strict fp32 reference under ambient autocast/TF32 settings.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=11) + ref = _ref_softmax_attn(q, k, v, causal=True) + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True + try: + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + out = op.forward_fp32(q, k, v, causal=True) + assert out.dtype == torch.float32 + assert torch.equal(out, ref) + assert torch.backends.cuda.matmul.allow_tf32 is True # restored + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU to exercise TF32") +def test_forward_fp32_disables_tf32_on_gpu(): + """On a TF32-enabled GPU, forward_fp32 stays true fp32, no worse than a TF32 path.""" + device = torch.device("cuda") + gen = torch.Generator(device=device).manual_seed(21) + q = torch.randn(2, _N_HEADS, 64, _HEAD_DIM, generator=gen, device=device) + k = torch.randn(2, _N_KV, 64, _HEAD_DIM, generator=gen, device=device) + v = torch.randn(2, _N_KV, 64, _HEAD_DIM, generator=gen, device=device) + ref = _ref_softmax_attn(q.double(), k.double(), v.double(), causal=True).float() + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + 
torch.backends.cuda.matmul.allow_tf32 = True # hostile ambient setting + try: + strict = NativeAttentionOp().forward_fp32(q, k, v, causal=True) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + peak = ref.abs().max().item() + strict_err = (strict - ref).abs().max().item() + print(f"\n[attention fp32-vs-tf32] strict_err={strict_err:.3g} peak={peak:.3g}") + assert strict_err <= 1.0e-3 * peak # fp32-tight, well under TF32 drift floor + + +# --------------------------------------------------------------------------- # +# Attention-specific correctness (closed-form, independent of the op's code) +# --------------------------------------------------------------------------- # +# With q=k=0 every score is 0, so softmax is uniform over the *visible* keys. +# Under causal masking query i sees keys 0..i, so out[i] == mean(v[0..i]). +def test_causal_uniform_attention_closed_form(): + """Causal masking: with uniform scores, out[i] == mean of keys 0..i.""" + B, H, S, D = 1, 1, 4, 3 + q = torch.zeros(B, H, S, D) + k = torch.zeros(B, H, S, D) + v = torch.arange(S * D, dtype=torch.float32).reshape(1, 1, S, D) # distinct rows + out = NativeAttentionOp().forward_fp32(q, k, v, causal=True) + expected = torch.stack([v[0, 0, : i + 1].mean(dim=0) for i in range(S)]) # cumulative mean + assert torch.allclose(out[0, 0], expected, atol=1e-6) + + +# Decode special case: a single query (Sq=1) with causal offset Skv-1+1=Skv masks +# nothing -> it must see all keys, i.e. equal the non-causal result. 
+def test_causal_decode_sees_all_keys(): + """Decode (Sq=1): causal masks nothing, so it equals the non-causal result.""" + op = NativeAttentionOp() + gen = torch.Generator().manual_seed(2) + q = torch.randn(2, _N_HEADS, 1, _HEAD_DIM, generator=gen) + k = torch.randn(2, _N_KV, 40, _HEAD_DIM, generator=gen) + v = torch.randn(2, _N_KV, 40, _HEAD_DIM, generator=gen) + assert torch.equal( + op.forward_fp32(q, k, v, causal=True), op.forward_fp32(q, k, v, causal=False) + ) + + +# GQA: 32 Q heads share 8 KV heads (g=4). Output keeps 32 heads, and the result +# matches an independent reference that expands KV with repeat_interleave. +def test_gqa_replication(): + """GQA: output keeps Hq=32 heads and matches the repeat_interleave reference.""" + q, k, v = _qkv(2, 8, 8, seed=3) + out = NativeAttentionOp().forward_fp32(q, k, v, causal=False) + assert out.shape == (2, _N_HEADS, 8, _HEAD_DIM) + assert out.shape[1] == 4 * k.shape[1] # 32 == g * 8 + assert torch.equal(out, _ref_softmax_attn(q, k, v, causal=False)) + + +def test_gqa_requires_divisible_heads(): + """Hq not divisible by Hkv is rejected (no valid GQA grouping).""" + gen = torch.Generator().manual_seed(31) + q = torch.randn(1, 6, 4, _HEAD_DIM, generator=gen) # 6 not divisible by 4 + k = torch.randn(1, 4, 4, _HEAD_DIM, generator=gen) + v = torch.randn(1, 4, 4, _HEAD_DIM, generator=gen) + with pytest.raises(ValueError, match="not divisible"): + NativeAttentionOp().forward_fp32(q, k, v, causal=False) + + +# scale defaults to 1/sqrt(head_dim); an explicit scale (incl. 0.0) is honored. +def test_scale_default_and_explicit(): + """Default scale is 1/sqrt(D); an explicit scale (incl. 0.0) is used verbatim.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=4) + assert torch.equal( + op.forward_fp32(q, k, v, causal=False), + op.forward_fp32(q, k, v, causal=False, scale=1.0 / math.sqrt(_HEAD_DIM)), + ) + # scale=0.0 -> all scores 0 -> uniform attention over all keys (mean of V). 
+ out0 = op.forward_fp32(q, k, v, causal=False, scale=0.0) + assert torch.allclose( + out0, + v.float().repeat_interleave(4, dim=1).mean(dim=2, keepdim=True).expand_as(out0), + atol=1e-6, + ) + + +# key_padding_mask (True=valid): padded key columns get zero weight, so the +# result equals attention computed over only the valid keys. +# +# NOTE: padding changes the softmax reduction width (Skv=10 vs Skv=6). Even +# though masked positions contribute exp(-inf)=0 to the sum, the internal +# reduction order of torch.softmax over a size-10 row vs a size-6 row may +# differ (vectorisation boundaries, intermediate rounding of partial sums), +# so bitwise equality across different Skv is NOT guaranteed in IEEE 754. +# We assert near-equality (atol=2e-6) which validates the masking semantics +# without over-constraining the floating-point reduction path. The observed +# drift is ~1.3e-6 and is platform-sensitive, so the threshold carries headroom. +_PADDING_ATOL = 2.0e-6 + + +def test_key_padding_mask_excludes_padded_keys(): + """key_padding_mask: padded keys get zero weight (≈ attending over valid keys only). + + Not bitwise-equal because the softmax reduction width differs (Skv=10 vs 6); + see comment above for rationale. 
+ """ + op = NativeAttentionOp() + gen = torch.Generator().manual_seed(5) + q = torch.randn(2, _N_HEADS, 6, _HEAD_DIM, generator=gen) + k_valid = torch.randn(2, _N_KV, 6, _HEAD_DIM, generator=gen) + v_valid = torch.randn(2, _N_KV, 6, _HEAD_DIM, generator=gen) + pad_k = torch.randn(2, _N_KV, 4, _HEAD_DIM, generator=gen) # 4 padding columns + pad_v = torch.randn(2, _N_KV, 4, _HEAD_DIM, generator=gen) + k = torch.cat([k_valid, pad_k], dim=2) + v = torch.cat([v_valid, pad_v], dim=2) + mask = torch.zeros(2, 10, dtype=torch.bool) + mask[:, :6] = True # first 6 valid, last 4 padding + + masked = op.forward_fp32(q, k, v, causal=False, key_padding_mask=mask) + valid_only = op.forward_fp32(q, k_valid, v_valid, causal=False) + + diff = (masked - valid_only).abs() + max_err = diff.max().item() + print(f"\n[padding mask] max_abs_err={max_err:.3g} (threshold={_PADDING_ATOL:.1g})") + assert torch.allclose( + masked, valid_only, atol=_PADDING_ATOL, rtol=0.0 + ), f"Padding-masked result diverges from valid-only by {max_err:.3g} > {_PADDING_ATOL}" + + +# A query whose every key is padded out has an all -inf row; naive softmax would +# emit NaN. The op defines such fully-masked rows as 0, keeping outputs and grads +# finite (NaN would poison both and break alignment against downstream kernels). 
+def test_fully_masked_query_returns_zero_not_nan(): + """All keys padded out -> the query yields 0 (not NaN), and grads stay finite.""" + gen = torch.Generator().manual_seed(9) + q = torch.randn(1, _N_HEADS, 4, _HEAD_DIM, generator=gen, requires_grad=True) + k = torch.randn(1, _N_KV, 4, _HEAD_DIM, generator=gen) + v = torch.randn(1, _N_KV, 4, _HEAD_DIM, generator=gen) + mask = torch.zeros(1, 4, dtype=torch.bool) # all False == every key is padding + + out = NativeAttentionOp().forward_fp32(q, k, v, causal=False, key_padding_mask=mask) + assert torch.isfinite(out).all() + assert torch.equal(out, torch.zeros_like(out)) + + # NaN would propagate through backward; assert the gradient is finite (zero here). + out.sum().backward() + assert torch.isfinite(q.grad).all() + + +# --------------------------------------------------------------------------- # +# Axis-B accuracy +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_B) +def test_dtype_path_accuracy(dtype: torch.dtype): + """Axis-B: the low-precision path drifts from fp32 by a bounded fraction of the output peak.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 16, 16, seed=2) + ref = op.forward_fp32(q, k, v, causal=True) # fp32 ground truth + cand = op.forward(q.to(dtype), k.to(dtype), v.to(dtype), causal=True) + assert cand.dtype == dtype + + err = (cand.float() - ref).abs() + peak = ref.abs().max() + max_abs, mean_abs = err.max().item(), err.mean().item() + print(f"\n[attention {dtype}] max_abs={max_abs:.4g} mean_abs={mean_abs:.4g} peak={peak:.4g}") + assert max_abs <= _DTYPE_REL_PEAK[dtype] * peak.item() + + +def test_output_shape(): + """Output shape is [B, Hq, Sq, D].""" + q, k, v = _qkv(3, 7, 7, seed=3) + out = NativeAttentionOp().forward(q, k, v, causal=True) + assert out.shape == (3, _N_HEADS, 7, _HEAD_DIM) + + +# --------------------------------------------------------------------------- # +# Axis-A batch invariance (bitwise, 
single-thread reduction order) +# --------------------------------------------------------------------------- # +# A sequence's attention output must not depend on how many other sequences +# share the batch. Compute on the full batch once, then slice -- never compute a +# slice on its own. Requires the pinned single-thread reduction order. +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_batch_invariance_slice(dtype: torch.dtype): + """Axis-A: a sequence's output is bitwise-independent of how many share the batch.""" + op = NativeAttentionOp() + q, k, v = _qkv(8, 16, 16, seed=5, dtype=dtype) + with _single_thread(): + full = op.forward(q, k, v, causal=True) + assert torch.equal(op.forward(q[:1], k[:1], v[:1], causal=True), full[:1]) + assert torch.equal(op.forward(q[3:5], k[3:5], v[3:5], causal=True), full[3:5]) + + +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_batch_invariance_chunked(dtype: torch.dtype): + """Axis-A (chunked): processing the batch in chunks and concatenating == one shot.""" + op = NativeAttentionOp() + q, k, v = _qkv(8, 16, 16, seed=6, dtype=dtype) + with _single_thread(): + full = op.forward(q, k, v, causal=True) + c1 = op.forward(q[:3], k[:3], v[:3], causal=True) + c2 = op.forward(q[3:], k[3:], v[3:], causal=True) + assert torch.equal(torch.cat([c1, c2], dim=0), full) + + +# --------------------------------------------------------------------------- # +# Purity / gradient / registry +# --------------------------------------------------------------------------- # +def test_inputs_not_mutated(): + """Purity: no input tensor is mutated in place.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=7) + qc, kc, vc = q.clone(), k.clone(), v.clone() + mask = torch.ones(2, 8, dtype=torch.bool) + mc = mask.clone() + op.forward(q, k, v, causal=True, key_padding_mask=mask) + op.forward_fp32(q, k, v, causal=True, key_padding_mask=mask) + assert torch.equal(q, qc) and torch.equal(k, kc) and torch.equal(v, vc) + assert 
torch.equal(mask, mc) + + +def test_gradient_matches_reference(): + """fp32 autograd grads match autograd through the double-precision reference. + + isfinite only rules out NaN/Inf -- it can't tell a correct gradient from a + wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV + contractions) is the most error-prone in the stack. Backprop a *random* + cotangent (not .sum(), whose all-ones cotangent collapses the contraction) + and compare q/k/v grads against autograd through the independent + _ref_softmax_attn computed in float64 (TF32-immune high-precision golden). + """ + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=8) + q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True) + + out = op.forward_fp32(q, k, v, causal=True) + gen = torch.Generator().manual_seed(8) + dy = torch.randn(out.shape, generator=gen, dtype=out.dtype) # seeded for reproducibility + out.backward(dy) + + qd, kd, vd = (t.detach().double().requires_grad_(True) for t in (q, k, v)) + _ref_softmax_attn(qd, kd, vd, causal=True).backward(dy.double()) + for t, td in ((q, qd), (k, kd), (v, vd)): + assert t.grad is not None and t.grad.shape == t.shape + torch.testing.assert_close(t.grad, td.grad.float(), rtol=1e-4, atol=1e-4) + + +def test_registry_dispatches_native_attention_op(): + """The registry resolves "attention" to the ground-truth NativeAttentionOp.""" + assert isinstance(kernel_registry.get_op("attention"), NativeAttentionOp) + + +# --------------------------------------------------------------------------- # +# Qwen3-8B LARGE real-scale GPU smoke test +# --------------------------------------------------------------------------- # +# The scores tensor [B=8, Hq=32, Sq=4096, Skv=4096] is ~17 GB in fp32, so the +# LARGE load point is GPU-only and skips without enough memory. SMALL/MEDIUM at +# real head dims already run on CPU above; this validates real prefill scale. 
+def _enough_gpu_memory(num_bytes: int) -> bool: + if not torch.cuda.is_available(): + return False + try: + free, _ = torch.cuda.mem_get_info() + except RuntimeError: + return False + return free > num_bytes + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU") +def test_attention_qwen3_8b_large_real_shape(): + """GPU smoke at LARGE Qwen3-8B prefill (batch=8, seq=4096, 32/8/128).""" + # Runtime memory check (not a collection-time skipif): free memory at + # collection time is not representative on a shared GPU. Naive fp32 attention + # peaks at ~3x the scores tensor (scores + masked copy + softmax probs all + # live transiently), so budget 3x the ~17 GB scores -> ~50 GB peak. This makes + # LARGE an H100-class (H-series nightly) test, skipping on smaller GPUs. + scores_bytes = 8 * _N_HEADS * 4096 * 4096 * 4 # ~17 GB + if not _enough_gpu_memory(scores_bytes * 3): + pytest.skip("not enough free GPU memory for the ~50 GB fp32 LARGE attention peak") + device = torch.device("cuda") + op = NativeAttentionOp() + gen = torch.Generator(device=device).manual_seed(0) + q = torch.randn(8, _N_HEADS, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + k = torch.randn(8, _N_KV, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + v = torch.randn(8, _N_KV, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + out = op.forward_fp32(q, k, v, causal=True) + assert out.shape == (8, _N_HEADS, 4096, _HEAD_DIM) + assert torch.isfinite(out).all() + # Axis-A: compute on full batch, then slice (no per-slice recompute). 
+ assert torch.equal(op.forward_fp32(q[:1], k[:1], v[:1], causal=True), out[:1]) diff --git a/tests/test_kv_cache_attention.py b/tests/test_kv_cache_attention.py new file mode 100644 index 0000000..a39e992 --- /dev/null +++ b/tests/test_kv_cache_attention.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Tests for NativeKVCacheAttnOp (ISSUE #108 WS1 ground-truth baseline). + +KV-cache (decode/incremental) attention: concat the step's new K/V onto the +cache along the seq axis, then run the *same* standard-softmax attention used for +prefill. Because it delegates to NativeAttentionOp, the central guarantees are: + + * Delegation equivalence: kv_cache(q, cache, new) == standard_attn(q, + cat([cache,new])) bitwise -- validates the wiring (cat dim, arg order, + fp32/dtype path pairing). + * Split-point invariance (Axis-A flavor): a token's output is the same whether + computed in one prefill shot or at decode time (its position split into cache + + new), with the new queries taken as the suffix so Sq == S_new. The score + matmul's M dimension differs across splits, so this is near-equal + (allclose, atol=2e-6), not bitwise. + * Prefill<->decode consistency: stepwise decode reproduces full-prefill outputs + up to a small tolerance. Here the softmax reduction *width* differs (step t + reduces over t+1 keys vs the full Skv with future positions masked to -inf), + so -- exactly as with key padding in standard attention -- IEEE 754 does not + guarantee bitwise equality; we assert allclose(atol=1e-6). + * Axis-B accuracy: the low-precision forward path drifts from forward_fp32 and + is checked with a tolerance relative to the output peak. + +This op covers ONLY the attention; QK-Norm and RoPE are applied before the call. 
+""" + +import contextlib + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.attention.kv_cache import NativeKVCacheAttnOp +from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp +from rl_engine.kernels.registry import kernel_registry + +# Qwen3-8B attention dims (synthetic tensors, no checkpoint). +_N_HEADS = 32 # Q heads +_N_KV = 8 # KV heads; GQA group g = 32 / 8 = 4 +_HEAD_DIM = 128 # 32 * 128 == 4096 == hidden + +# Axis-B: max abs error as a fraction of the output peak magnitude (same basis as +# the standard-attention test, since this op shares its reduction). +_DTYPE_REL_PEAK = {torch.bfloat16: 3.0e-2, torch.float16: 5.0e-3} + +# Prefill<->decode reduction width differs -> not bitwise; bounded near-equality. +_DECODE_ATOL = 1.0e-6 + +# key_padding_mask compares a softmax over (S_past+S_new) keys against one over the +# valid-only subset, so the reduction widths differ (same situation as the standard +# attention padding test). The drift is ~1.3e-6 and platform-sensitive, so this +# cross-width comparison carries extra headroom over the closed-form decode checks. 
+_PADDING_ATOL = 2.0e-6 + + +def _cpu_fp16_matmul_supported() -> bool: + """Probe whether this CPU backend implements float16 matmul.""" + try: + _ = torch.randn(2, 2, dtype=torch.float16) @ torch.randn(2, 2, dtype=torch.float16) + return True + except RuntimeError: + return False + + +_FP16_IF_CPU_MATMUL_SUPPORTED = pytest.param( + torch.float16, + marks=pytest.mark.skipif( + not _cpu_fp16_matmul_supported(), + reason="CPU float16 matmul unsupported on this backend", + ), +) +_DTYPES_AXIS_B = (torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) + + +@contextlib.contextmanager +def _single_thread(): + """Pin CPU GEMM to one thread so the matmul reduction order is stable.""" + prev = torch.get_num_threads() + torch.set_num_threads(1) + try: + yield + finally: + torch.set_num_threads(prev) + + +def _q(batch, sq, *, seed, dtype=torch.float32): + gen = torch.Generator().manual_seed(seed) + return torch.randn(batch, _N_HEADS, sq, _HEAD_DIM, generator=gen, dtype=dtype) + + +def _kv(batch, s, *, seed, dtype=torch.float32): + """K/V of KV-head count, length s.""" + gen = torch.Generator().manual_seed(seed) + k = torch.randn(batch, _N_KV, s, _HEAD_DIM, generator=gen, dtype=dtype) + v = torch.randn(batch, _N_KV, s, _HEAD_DIM, generator=gen, dtype=dtype) + return k, v + + +# --------------------------------------------------------------------------- # +# Delegation equivalence: kv_cache == standard_attn on the concatenation. 
+# --------------------------------------------------------------------------- # +def test_forward_fp32_equals_standard_attn_on_concat(): + """forward_fp32 must equal NativeAttentionOp.forward_fp32 on cat([cache,new]).""" + op = NativeKVCacheAttnOp() + ref = NativeAttentionOp() + b, s_past, s_new = 2, 6, 3 + q = _q(b, s_new, seed=1) # Sq == S_new (the new tokens' queries) + k_cache, v_cache = _kv(b, s_past, seed=2) + k_new, v_new = _kv(b, s_new, seed=3) + + with _single_thread(): + got = op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + k_full = torch.cat([k_cache, k_new], dim=2) + v_full = torch.cat([v_cache, v_new], dim=2) + want = ref.forward_fp32(q, k_full, v_full, causal=True) + assert torch.equal(got, want) + assert got.shape == (b, _N_HEADS, s_new, _HEAD_DIM) + + +def test_forward_equals_standard_attn_on_concat(): + """forward (input-dtype path) must equal NativeAttentionOp.forward on cat([cache,new]).""" + op = NativeKVCacheAttnOp() + ref = NativeAttentionOp() + b, s_past, s_new = 2, 6, 3 + q = _q(b, s_new, seed=1) + k_cache, v_cache = _kv(b, s_past, seed=2) + k_new, v_new = _kv(b, s_new, seed=3) + + with _single_thread(): + got = op.forward(q, k_cache, v_cache, k_new, v_new, causal=True) + k_full = torch.cat([k_cache, k_new], dim=2) + v_full = torch.cat([v_cache, v_new], dim=2) + want = ref.forward(q, k_full, v_full, causal=True) + assert torch.equal(got, want) + + +# --------------------------------------------------------------------------- # +# Split-point invariance: the cache/new boundary must not change the result. +# --------------------------------------------------------------------------- # +def test_cache_split_point_invariance(): + """A decoded suffix matches the matching slice of full prefill (atol=2e-6). + + This is the soul of KV-cache correctness: a token computed during prefill (in + one shot) vs at decode time (its position split into cache + new) yields the + same attention output. 
For each split, the *new* queries are the suffix + ``q_full[:, :, split:]`` -- the positions that correspond to ``k_new`` -- so + the op's Sq == S_new contract holds. + + Not bitwise: although the decode call and prefill reduce over the same + Skv = total, the score matmul's M dimension differs (Sq = total - split vs + total), and the GEMM may tile/round that dimension differently, so IEEE 754 + only guarantees near-equality. Observed drift is ~1.4e-6; 2e-6 carries + headroom (cf. key-padding tolerance in standard attention). + """ + op = NativeKVCacheAttnOp() + ref = NativeAttentionOp() + b, total = 2, 8 + q_full = _q(b, total, seed=10) + k_full, v_full = _kv(b, total, seed=11) + split_atol = 2.0e-6 + + with _single_thread(): + prefill = ref.forward_fp32(q_full, k_full, v_full, causal=True) + for split in (0, 1, 4): # all-new, 1 cached, half cached; split=total would leave new empty + k_cache, k_new = k_full[:, :, :split], k_full[:, :, split:] + v_cache, v_new = v_full[:, :, :split], v_full[:, :, split:] + got = op.forward_fp32(q_full[:, :, split:], k_cache, v_cache, k_new, v_new, causal=True) + want = prefill[:, :, split:] + max_err = (got - want).abs().max().item() + assert torch.allclose( + got, want, atol=split_atol, rtol=0.0 + ), f"split {split} diverges from prefill by {max_err:.3g} > {split_atol}" + + +# --------------------------------------------------------------------------- # +# Prefill <-> decode consistency (reduction width differs -> near-equal). +# --------------------------------------------------------------------------- # +def test_stepwise_decode_matches_full_prefill(): + """Token-by-token decode reproduces full-prefill outputs (atol=1e-6). + + Not bitwise: at step t the softmax reduces over t+1 keys, whereas prefill + reduces over the full Skv with future positions masked to -inf -- a different + reduction width, so IEEE 754 only guarantees near-equality (cf. key padding + in standard attention). 
+ """ + op = NativeKVCacheAttnOp() + ref = NativeAttentionOp() + b, seq = 2, 7 + q_all = _q(b, seq, seed=20) + k_all, v_all = _kv(b, seq, seed=21) + + with _single_thread(): + # Full prefill: one shot over the whole sequence, causal. + prefill = ref.forward_fp32(q_all, k_all, v_all, causal=True) # [B, Hq, seq, D] + + # Stepwise decode: at step t, cache = positions [0,t), new = position t. + for t in range(seq): + q_t = q_all[:, :, t : t + 1] # the query for position t (Sq=1) + k_cache, v_cache = k_all[:, :, :t], v_all[:, :, :t] + k_new, v_new = k_all[:, :, t : t + 1], v_all[:, :, t : t + 1] + decode_t = op.forward_fp32(q_t, k_cache, v_cache, k_new, v_new, causal=True) + max_err = (decode_t - prefill[:, :, t : t + 1]).abs().max().item() + assert torch.allclose( + decode_t, prefill[:, :, t : t + 1], atol=_DECODE_ATOL, rtol=0.0 + ), f"decode step {t} diverges from prefill by {max_err:.3g} > {_DECODE_ATOL}" + + +def test_batch_invariance_slice(): + """Axis-A: a row computed in a batch-of-N is bitwise identical to batch-of-1. + + Each query row reduces over its own keys independently of how many sequences + share the batch, so slicing row i out of the batch-N output must equal running + row i alone. CPU GEMM is pinned to one thread so the matmul reduction order is + batch-independent (multi-threaded GEMM can split by batch and break bitwise). 
+ """ + op = NativeKVCacheAttnOp() + n, s_past, s_new = 4, 6, 2 + q = _q(n, s_new, seed=100) + k_cache, v_cache = _kv(n, s_past, seed=101) + k_new, v_new = _kv(n, s_new, seed=102) + + with _single_thread(): + full = op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + for i in range(n): + row = op.forward_fp32( + q[i : i + 1], + k_cache[i : i + 1], + v_cache[i : i + 1], + k_new[i : i + 1], + v_new[i : i + 1], + causal=True, + ) + assert torch.equal(full[i : i + 1], row), f"batch row {i} not invariant" + + +def test_empty_cache_equals_plain_attention(): + """S_past=0 (pure prefill) delegates to plain attention over k_new/v_new.""" + op = NativeKVCacheAttnOp() + ref = NativeAttentionOp() + b, seq = 2, 5 + q = _q(b, seq, seed=30) + k_new, v_new = _kv(b, seq, seed=31) + k_cache = k_new[:, :, :0] # [B, Hkv, 0, D] + v_cache = v_new[:, :, :0] + + with _single_thread(): + got = op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + want = ref.forward_fp32(q, k_new, v_new, causal=True) + assert torch.equal(got, want) + + +# --------------------------------------------------------------------------- # +# Decode (Sq=1) sees the whole cache; closed-form uniform check. +# --------------------------------------------------------------------------- # +def test_decode_single_query_uniform_attention(): + """With identical keys, a single decode query attends uniformly -> mean of V.""" + op = NativeKVCacheAttnOp() + b, s_past, s_new = 1, 4, 1 + q = _q(b, s_new, seed=40) + # All keys identical -> all scores equal -> softmax uniform over all S_past+S_new. 
+ k = torch.ones(b, _N_KV, 1, _HEAD_DIM) + k_cache = k.expand(b, _N_KV, s_past, _HEAD_DIM).contiguous() + k_new = k.expand(b, _N_KV, s_new, _HEAD_DIM).contiguous() + gen = torch.Generator().manual_seed(41) + v_cache = torch.randn(b, _N_KV, s_past, _HEAD_DIM, generator=gen) + v_new = torch.randn(b, _N_KV, s_new, _HEAD_DIM, generator=gen) + + out = op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + v_full = torch.cat([v_cache, v_new], dim=2) # [B, Hkv, Skv, D] + expected_kv = v_full.mean(dim=2, keepdim=True) # uniform avg over keys + expected = expected_kv.repeat_interleave(_N_HEADS // _N_KV, dim=1) # GQA broadcast + assert torch.allclose(out, expected, atol=1e-5) + + +# --------------------------------------------------------------------------- # +# key_padding_mask over the concatenated length. +# --------------------------------------------------------------------------- # +def test_key_padding_mask_excludes_padded_keys(): + """Padding columns (over S_past+S_new) get zero weight (~ attending valid keys).""" + op = NativeKVCacheAttnOp() + b, s_past, s_new = 2, 5, 3 + q = _q(b, s_new, seed=50) + k_cache, v_cache = _kv(b, s_past, seed=51) + k_new, v_new = _kv(b, s_new, seed=52) + skv = s_past + s_new + + # Mask out the last 2 cached keys for one batch row; non-causal to isolate padding. + mask = torch.ones(b, skv, dtype=torch.bool) + mask[0, s_past - 2 : s_past] = False + + with _single_thread(): + masked = op.forward_fp32( + q, k_cache, v_cache, k_new, v_new, causal=False, key_padding_mask=mask + ) + # Equivalent: drop those keys entirely from the valid row. 
+ keep = mask[0] + k_full = torch.cat([k_cache, k_new], dim=2) + v_full = torch.cat([v_cache, v_new], dim=2) + ref = NativeAttentionOp() + valid_only_row0 = ref.forward_fp32( + q[:1], k_full[:1][:, :, keep], v_full[:1][:, :, keep], causal=False + ) + assert torch.allclose(masked[:1], valid_only_row0, atol=_PADDING_ATOL, rtol=0.0) + + +# --------------------------------------------------------------------------- # +# Axis-B accuracy: low-precision forward vs fp32 ground truth. +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_B) +def test_dtype_path_accuracy(dtype: torch.dtype): + """forward(dtype) tracks forward_fp32 within the per-dtype peak-relative tolerance.""" + op = NativeKVCacheAttnOp() + b, s_past, s_new = 2, 12, 4 + q = _q(b, s_new, seed=60, dtype=dtype) + k_cache, v_cache = _kv(b, s_past, seed=61, dtype=dtype) + k_new, v_new = _kv(b, s_new, seed=62, dtype=dtype) + + got = op.forward(q, k_cache, v_cache, k_new, v_new, causal=True) + ref = op.forward_fp32( + q.float(), + k_cache.float(), + v_cache.float(), + k_new.float(), + v_new.float(), + causal=True, + ) + assert got.dtype == dtype + peak = ref.abs().max().item() + max_err = (got.float() - ref).abs().max().item() + assert ( + max_err <= _DTYPE_REL_PEAK[dtype] * peak + ), f"{dtype}: max_abs_err={max_err:.3g} > {_DTYPE_REL_PEAK[dtype]:.1%} of peak {peak:.3g}" + + +# --------------------------------------------------------------------------- # +# Shape / GQA / purity / registry. 
+# --------------------------------------------------------------------------- # +def test_output_shape_follows_q(): + op = NativeKVCacheAttnOp() + b, s_past, s_new = 3, 9, 2 + q = _q(b, s_new, seed=70) + k_cache, v_cache = _kv(b, s_past, seed=71) + k_new, v_new = _kv(b, s_new, seed=72) + out = op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + assert out.shape == (b, _N_HEADS, s_new, _HEAD_DIM) + + +def test_gqa_requires_divisible_heads(): + """q heads not divisible by KV heads raises (propagated from NativeAttentionOp).""" + op = NativeKVCacheAttnOp() + b = 1 + q = torch.randn(b, 7, 1, _HEAD_DIM) # 7 not divisible by _N_KV=8 + k_cache, v_cache = _kv(b, 3, seed=80) + k_new, v_new = _kv(b, 1, seed=81) + with pytest.raises(ValueError): + op.forward_fp32(q, k_cache, v_cache, k_new, v_new) + + +def test_misaligned_q_length_raises(): + """q must hold exactly the new positions (Sq == S_new); a mismatch raises.""" + op = NativeKVCacheAttnOp() + b, s_past, s_new = 2, 5, 3 + k_cache, v_cache = _kv(b, s_past, seed=85) + k_new, v_new = _kv(b, s_new, seed=86) + q = _q(b, s_new + 1, seed=87) # Sq=4 != S_new=3 + with pytest.raises(ValueError): + op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + with pytest.raises(ValueError): + op.forward(q, k_cache, v_cache, k_new, v_new, causal=True) + + +def test_inputs_not_mutated(): + """Pure op: cache/new tensors are not modified in place.""" + op = NativeKVCacheAttnOp() + b, s_past, s_new = 2, 5, 2 + q = _q(b, s_new, seed=90) + k_cache, v_cache = _kv(b, s_past, seed=91) + k_new, v_new = _kv(b, s_new, seed=92) + snapshots = [t.clone() for t in (q, k_cache, v_cache, k_new, v_new)] + op.forward_fp32(q, k_cache, v_cache, k_new, v_new, causal=True) + for orig, snap in zip((q, k_cache, v_cache, k_new, v_new), snapshots): + assert torch.equal(orig, snap) + + +def test_registry_dispatches_kv_cache_attn_op(): + """The registry resolves "kv_cache_attention" to NativeKVCacheAttnOp.""" + assert 
isinstance(kernel_registry.get_op("kv_cache_attention"), NativeKVCacheAttnOp)
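
Reviewer note: the semantics these two test files pin down (bottom-right causal alignment for decode, fully masked rows defined as 0, and KV-cache attention as "concatenate, then delegate") can be sketched without PyTorch. The snippet below is a minimal single-head illustration on plain Python lists, not the repository's `NativeAttentionOp` or `NativeKVCacheAttnOp`. The names `naive_attention`, `kv_cache_attention`, and `key_mask` are hypothetical, GQA head sharing is omitted, and `q` is assumed to be non-empty.

```python
import math


def naive_attention(q, k, v, causal=True, scale=None, key_mask=None):
    """Single-head softmax attention on plain lists (illustrative sketch only).

    q: Sq x D, k: Skv x D, v: Skv x Dv; key_mask: Skv bools, True = keep.
    """
    sq, skv, d = len(q), len(k), len(q[0])
    scale = 1.0 / math.sqrt(d) if scale is None else scale
    # Bottom-right alignment: query i sees keys 0..i+offset, so a single
    # decode query (Sq=1) sees every key.
    offset = skv - sq
    out = []
    for i in range(sq):
        visible = [
            j
            for j in range(skv)
            if (not causal or j <= i + offset)
            and (key_mask is None or key_mask[j])
        ]
        if not visible:
            # Fully masked row: defined as 0 instead of letting softmax emit NaN.
            out.append([0.0] * len(v[0]))
            continue
        scores = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in visible]
        row_max = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - row_max) for s in scores]
        total = sum(weights)
        out.append(
            [
                sum(w * v[j][c] for w, j in zip(weights, visible)) / total
                for c in range(len(v[0]))
            ]
        )
    return out


def kv_cache_attention(q, k_cache, v_cache, k_new, v_new, **kwargs):
    """Concatenate cache and new K/V along the sequence axis, then delegate."""
    if len(q) != len(k_new):
        raise ValueError("q must hold exactly the new positions (Sq == S_new)")
    return naive_attention(q, k_cache + k_new, v_cache + v_new, **kwargs)
```

Because this sketch only ever touches visible keys, its stepwise decode and prefill rows run the same arithmetic in the same order. A vectorized kernel reduces over masked positions as well, and that difference in reduction width is why the tests in this diff compare with `allclose` rather than `torch.equal`.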