diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 755f426..5c2a2a0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,6 +68,10 @@ jobs: python -m pytest rl_engine/tests/test_dispatch.py -v PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs + - name: Run Attention Ground-Truth Tests (CPU-safe) + run: | + python -m pytest tests/test_attention.py -v -k "not large and not gpu" + docs: runs-on: ubuntu-latest steps: diff --git a/docs/.nav.yml b/docs/.nav.yml index 9321713..1274acb 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -12,6 +12,7 @@ nav: - Operators: - operators/README.md - operators/activation.md + - operators/attention.md - operators/fused-logp.md - operators/linear-logp.md - operators/grpo-loss.md diff --git a/docs/operators/README.md b/docs/operators/README.md index eb2f5de..aa73396 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -19,6 +19,7 @@ Every operator page should include: ## Current Pages - [SiLU / SwiGLU Activation](activation.md) +- [Standard Attention](attention.md) - [Fused LogP](fused-logp.md) - [Fused Linear LogP](linear-logp.md) - [GRPO Loss](grpo-loss.md) diff --git a/docs/operators/attention.md b/docs/operators/attention.md new file mode 100644 index 0000000..e3cb5f9 --- /dev/null +++ b/docs/operators/attention.md @@ -0,0 +1,159 @@ +# Standard Softmax Attention + +The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a +**WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of +the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and +friends) are validated against. + +- **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive + softmax with a **fixed reduction order**, deliberately **not** + `F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order + is unspecified and would break the batch-invariance contract). + +This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before* +the call (see the chain), so the `q`, `k` passed in are already normalized and rotated. + +```text +q --\ +k ----softmax(QKᵀ/√d + mask)·V--> out +v --/ +``` + +## Entry Point +```python +from rl_engine.kernels.registry import kernel_registry + +attn = kernel_registry.get_op("attention") + +# Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache) +out = attn(q, k, v, causal=True) # [B, 32, Sq, 128] +out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale +out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep +``` + +The op exposes the WS1 dual-path contract: + +- `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy + candidate / dtype-behavior path). +- `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth + golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of + the caller's ambient precision context. + +> **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production +> SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under +> `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap. + +## Backends + +| Backend | Wrapper | Native symbol | Status | +| --- | --- | --- | --- | +| PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. | +| CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. | + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. | +| `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. | +| `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. | +| `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. | +| `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. An explicit value (incl. `0.0`) is used verbatim. | +| `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. | +| output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). | + +**GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with +`repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`. + +**Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single +expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees +the whole cache). + +Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)` +preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on +the inputs' device. + +## Dispatch Behavior + +`kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On +`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op +(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` -> +`forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden +path. When fused attention kernels land, they are prepended to the priority list and the native +op becomes the fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is +a separate dispatch chain and is unaffected. + +## Accuracy + +Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled): + +```python +qf, kf, vf = q.float(), k.float(), v.float() +if Hkv != Hq: # GQA: replicate KV, 32 Q / 8 KV, r = 4 + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) +scale = scale if scale is not None else 1.0 / math.sqrt(D) # D=128 → 1/√128 +scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] +if causal: # offset covers prefill + decode + m = torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1) + scores = scores.masked_fill(m, float("-inf")) +if key_padding_mask is not None: # True = keep ; False columns → -inf + scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) +probs = torch.softmax(scores, dim=-1) # subtracts per-row max internally +out = torch.matmul(probs, vf) # [B, Hq, Sq, D] +``` + +- **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and + autocast disabled so it is not silently downcast by the caller's ambient context. +- **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions + over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a + tolerance, not bitwise equality. +- **Axis A — batch invariance**: each query row reduces over the keys independently of how many + sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`) + across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax + reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width + (e.g. Skv=10 vs 6), so the masked result only matches the valid-only result up to a small + tolerance (`atol=2e-6`), not bitwise, in IEEE 754. +- **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` + row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): + + | dtype | max_abs / peak | threshold (rel-peak) | + | --- | --- | --- | + | bfloat16 | ~0.56 % | 3 % | + | float16 | ~0.07 % | 0.5 % | + +## Performance Notes + +Reference operator — no fused kernel or benchmark yet. Downstream fused attention kernels carry +their own benchmarks and are measured against this reference for correctness. At the LARGE +Qwen3-8B load point (`B=8`, `Skv=4096`, `Hq=32`) the fp32 scores tensor alone is ~17 GB and the +naive path peaks at ~3× that, so the LARGE smoke test is GPU-only and skips without enough +memory. + +## Tests + +```bash +python -m pytest tests/test_attention.py -v +``` + +Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile +autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard, +scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch +invariance (slice + chunked, bitwise; padding is near-equality only, see below), input purity, +gradient flow, registry dispatch, and a +GPU-only LARGE Qwen3-8B real-shape smoke test. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/attention/standard_attn.py` +- `rl_engine/kernels/registry.py` +- `tests/test_attention.py` + +## Known Limitations + +- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). +- `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise). +- The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking, + so the LARGE load point is memory-heavy and GPU-only. +- Covers softmax attention only; QK-Norm and RoPE are applied before the call. diff --git a/rl_engine/kernels/ops/pytorch/attention/standard_attn.py b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py new file mode 100644 index 0000000..804efb6 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py @@ -0,0 +1,191 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import math +from contextlib import contextmanager, nullcontext +from typing import Optional + +import torch + + +class NativeAttentionOp: + """ + Pure PyTorch native standard-softmax attention reference. + out = softmax(Q Kᵀ * scale + masks) @ V + + Hand-written naive softmax -- deliberately NOT + ``F.scaled_dot_product_attention`` / flash / mem-efficient attention, whose + reduction order is unspecified and would break the batch-invariance (Axis-A) + contract. This op defines the *correct answer* the fused kernels align to. + + Qwen3-8B shapes: q ``[B, 32, Sq, 128]``, k/v ``[B, 8, Skv, 128]`` (GQA group + g = 32/8 = 4), scale = 1/sqrt(head_dim) = 1/sqrt(128). Heads precede seq in + the layout. This op covers ONLY the softmax attention; QK-Norm and RoPE are + applied *before* the call (see the chain test) -- the q,k passed in are + already normalized and rotated. + + This is a reduction over the key dimension (Skv): the low-precision + ``forward`` path accumulates in the input dtype and therefore drifts from + the fp32 ``forward_fp32`` ground truth, so Axis-B accuracy uses a tolerance + (``torch.allclose``), not bitwise equality. Axis-A batch invariance still + holds bitwise within a single dtype (each query row reduces over the keys + independently of how many sequences share the batch). + + Masking conventions: + * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1, valid for + both prefill (Sq==Skv) and decode (Sq padded + key columns set to -inf (matches reference_ops.py: True=keep, False=mask). + """ + + def __init__(self) -> None: + """No state; the op is a pure function over (q, k, v, ...).""" + + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Alias for ``forward`` so the op is callable like a module.""" + return self.forward(q, k, v, causal=causal, scale=scale, key_padding_mask=key_padding_mask) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Canonical entry: attend in the input dtype, output the input dtype. + This is the dtype-behavior path used as the Axis-B accuracy candidate. + """ + return self._attention( + q, + k, + v, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + compute_dtype=q.dtype, + output_dtype=q.dtype, + ) + + def forward_fp32( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Ground truth: upcast to fp32, accumulate in fp32, force fp32 output. + + The whole score->softmax->value path is wrapped to disable autocast and + TF32 so this stays a true fp32 reference regardless of the caller's + ambient precision context. + """ + return self._attention( + q, + k, + v, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + compute_dtype=torch.float32, + output_dtype=torch.float32, + strict_fp32=True, + ) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool, + scale: Optional[float], + key_padding_mask: Optional[torch.Tensor], + compute_dtype: torch.dtype, + output_dtype: torch.dtype, + strict_fp32: bool = False, + ) -> torch.Tensor: + """Core softmax attention: cast to ``compute_dtype``, score, mask, softmax, + weighted-sum over V, cast out. ``strict_fp32`` disables autocast/TF32 so the + fp32 reference is not silently downcast by the caller's ambient context. + """ + Hq, Sq, D = q.shape[1], q.shape[2], q.shape[3] + Hkv, Skv = k.shape[1], k.shape[2] + if Hq % Hkv != 0: + raise ValueError(f"Hq={Hq} not divisible by Hkv={Hkv} (GQA group)") + + ctx = NativeAttentionOp._strict_fp32_math(q.device.type) if strict_fp32 else nullcontext() + with ctx: + qf = q.to(compute_dtype) + kf = k.to(compute_dtype) + vf = v.to(compute_dtype) + + # GQA: replicate each KV head g=Hq//Hkv times (Qwen3: 32/8 -> 4). + # repeat_interleave (not repeat) keeps each KV head's copies adjacent + # so query head h maps to KV head h // g. + if Hkv != Hq: + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) + + # scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept. + scale = scale if scale is not None else (1.0 / math.sqrt(D)) + scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] + + # Causal: offset Skv-Sq+1 covers prefill (Sq==Skv) and decode (Sq -inf. + if key_padding_mask is not None: + pad = ~key_padding_mask + scores = scores.masked_fill(pad[:, None, None, :], float("-inf")) + + probs = torch.softmax(scores, dim=-1) # subtracts row max internally + if key_padding_mask is not None: + # A query whose every valid key is padded out has an all -inf row; + # softmax would emit NaN. Define such fully-masked rows as 0 (no key + # contributes), keeping outputs/grads finite. Only key_padding_mask can + # fully mask a row -- causal alone always leaves a query its own key -- + # so this guard lives in the padding branch and the no-pad path is + # untouched. Row-independent, so Axis-A batch invariance still holds. + all_masked = ~torch.isfinite(scores).any(dim=-1, keepdim=True) + probs = torch.where(all_masked, torch.zeros_like(probs), probs) + out = torch.matmul(probs, vf) # [B, Hq, Sq, D] + return out.to(output_dtype) + + @staticmethod + @contextmanager + def _strict_fp32_math(device_type: str): + """Disable autocast and TF32 for a true fp32 path, restoring state after.""" + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + with torch.autocast(device_type=device_type, enabled=False): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 5e63d84..c7c0b00 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -59,6 +59,11 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): PYTORCH_NATIVE_SILU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSiLUOp" PYTORCH_NATIVE_SWIGLU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSwiGLUOp" + # WS1 pure-PyTorch ground-truth attention reference (hand-written fp32 softmax). + # Distinct from PYTORCH_ATTN above, which is the production SDPA fallback. + PYTORCH_NATIVE_ATTENTION = ( + "rl_engine.kernels.ops.pytorch.attention.standard_attn.NativeAttentionOp" + ) # WS1 pure-PyTorch ground-truth linear ops PYTORCH_NATIVE_LM_HEAD = "rl_engine.kernels.ops.pytorch.linear.lm_head.NativeLMHeadOp" # WS1 pure-PyTorch ground-truth embedding ops @@ -96,6 +101,7 @@ def __init__(self): OpBackend.PYTORCH_NATIVE, ], "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_ATTN], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], @@ -113,6 +119,7 @@ def __init__(self): OpBackend.PYTORCH_ATTN, OpBackend.TRITON_GENERIC, ], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], @@ -125,6 +132,7 @@ def __init__(self): "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_ATTN], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000..ffb5558 --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,478 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Tests for NativeAttentionOp (ISSUE #108 WS1 ground-truth baseline). + +Standard-softmax attention: out = softmax(Q Kᵀ * scale + masks) @ V, hand-written +(NOT F.scaled_dot_product_attention) so the reduction order is fixed. Like +lm_head this is a *reduction* (over the key dim Skv), so: + + * Axis-B (accuracy): the low-precision ``forward`` path accumulates in the + input dtype and drifts from the fp32 ``forward_fp32`` ground truth. It is + checked with a tolerance relative to the output peak magnitude, not bitwise + -- attention outputs are convex combinations of V rows, so many entries sit + near zero while the accumulated error tracks the reduction length. + * Axis-A (batch invariance): bitwise within a single dtype, but only once the + CPU reduction order is pinned. Multi-threaded CPU GEMM splits the matmul + reduction by the batch dimension, which silently breaks bitwise batch + invariance. ``_single_thread`` fixes the reduction order; it is the local + stand-in for the planned testing/determinism.py::deterministic_context. + +This op covers ONLY the softmax attention; QK-Norm and RoPE are applied before +the call (see the chain test) -- the q,k here are plain synthetic tensors. +""" + +import contextlib +import math + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp +from rl_engine.kernels.registry import kernel_registry + +# Qwen3-8B attention dims (synthetic tensors, no checkpoint). Unlike embedding / +# lm_head (whose multi-GB weight forces shrinking), attention's cost is the +# scores tensor [B, Hq, Sq, Skv], so the *real* head dims are cheap at a SMALL +# (batch, seq) load point and kept real here. Only LARGE (8, 4096) is GPU-only. +_N_HEADS = 32 # Q heads +_N_KV = 8 # KV heads; GQA group g = 32 / 8 = 4 +_HEAD_DIM = 128 # 32 * 128 == 4096 == hidden + +# Axis-B: max abs error as a fraction of the output peak magnitude. Calibrated +# from measured SMALL drift (bf16 ~1% of peak, fp16 ~0.1%) with headroom. +_DTYPE_REL_PEAK = {torch.bfloat16: 3.0e-2, torch.float16: 5.0e-3} + + +def _cpu_fp16_matmul_supported() -> bool: + """Probe whether this CPU backend implements float16 matmul.""" + try: + _ = torch.randn(2, 2, dtype=torch.float16) @ torch.randn(2, 2, dtype=torch.float16) + return True + except RuntimeError: + return False + + +# CPU half-precision matmul is backend/ISA-dependent and may be unimplemented on +# some runners -- gate the fp16 axis so a missing kernel skips rather than fails. +_FP16_IF_CPU_MATMUL_SUPPORTED = pytest.param( + torch.float16, + marks=pytest.mark.skipif( + not _cpu_fp16_matmul_supported(), + reason="CPU float16 matmul unsupported on this backend", + ), +) +_DTYPES_AXIS_B = (torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) +_DTYPES_AXIS_A = (torch.float32, torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) + + +@contextlib.contextmanager +def _single_thread(): + """Pin CPU GEMM to one thread so the matmul reduction order is batch-independent.""" + prev = torch.get_num_threads() + torch.set_num_threads(1) + try: + yield + finally: + torch.set_num_threads(prev) + + +# Shared helpers -- fixed-seed Generator for determinism / reproducibility. +def _qkv(batch, sq, skv, *, seed, dtype=torch.float32, n_heads=_N_HEADS, n_kv=_N_KV, d=_HEAD_DIM): + """Fixed-seed random q [B,Hq,Sq,D], k/v [B,Hkv,Skv,D] for reproducibility.""" + gen = torch.Generator().manual_seed(seed) + q = torch.randn(batch, n_heads, sq, d, generator=gen, dtype=dtype) + k = torch.randn(batch, n_kv, skv, d, generator=gen, dtype=dtype) + v = torch.randn(batch, n_kv, skv, d, generator=gen, dtype=dtype) + return q, k, v + + +def _ref_softmax_attn(q, k, v, *, causal, scale=None, key_padding_mask=None): + """Independent naive-softmax reference mirroring the contract (GQA, masks). + + Dtype-preserving: pass fp32 tensors for a bitwise fp32 check, or .double() + tensors for a TF32-immune high-precision reference. + """ + qf, kf, vf = q, k, v + Hq, Sq, D = qf.shape[1], qf.shape[2], qf.shape[3] + Hkv, Skv = kf.shape[1], kf.shape[2] + if Hkv != Hq: + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) + s = torch.matmul(qf, kf.transpose(-1, -2)) * ( + scale if scale is not None else 1.0 / math.sqrt(D) + ) + if causal: + m = torch.triu( + torch.ones(Sq, Skv, dtype=torch.bool, device=qf.device), diagonal=Skv - Sq + 1 + ) + s = s.masked_fill(m, float("-inf")) + if key_padding_mask is not None: + s = s.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) + return torch.softmax(s, dim=-1) @ vf + + +# --------------------------------------------------------------------------- # +# fp32 ground-truth correctness +# --------------------------------------------------------------------------- # +# forward_fp32 == the independent naive fp32 reference, bitwise. Both fix the +# same reduction order, so this validates the op's wiring (transpose dims, scale, +# masks) exactly. TF32 is pinned off so the fp32 forward path matches too. +def test_forward_fp32_matches_independent_reference(): + """forward_fp32 (and the fp32 forward path, TF32 off) is bitwise-equal to a + naive fp32 reference.""" + q, k, v = _qkv(2, 16, 16, seed=1) # Qwen3 32/8/128, SMALL prefill + ref = _ref_softmax_attn(q, k, v, causal=True) + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + assert torch.equal(NativeAttentionOp().forward_fp32(q, k, v, causal=True), ref) + assert torch.equal(NativeAttentionOp().forward(q, k, v, causal=True), ref) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + +def test_forward_fp32_ignores_ambient_autocast_and_restores_tf32(): + """forward_fp32 is a strict fp32 reference under ambient autocast/TF32 settings.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=11) + ref = _ref_softmax_attn(q, k, v, causal=True) + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True + try: + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + out = op.forward_fp32(q, k, v, causal=True) + assert out.dtype == torch.float32 + assert torch.equal(out, ref) + assert torch.backends.cuda.matmul.allow_tf32 is True # restored + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU to exercise TF32") +def test_forward_fp32_disables_tf32_on_gpu(): + """On a TF32-enabled GPU, forward_fp32 stays true fp32, no worse than a TF32 path.""" + device = torch.device("cuda") + gen = torch.Generator(device=device).manual_seed(21) + q = torch.randn(2, _N_HEADS, 64, _HEAD_DIM, generator=gen, device=device) + k = torch.randn(2, _N_KV, 64, _HEAD_DIM, generator=gen, device=device) + v = torch.randn(2, _N_KV, 64, _HEAD_DIM, generator=gen, device=device) + ref = _ref_softmax_attn(q.double(), k.double(), v.double(), causal=True).float() + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True # hostile ambient setting + try: + strict = NativeAttentionOp().forward_fp32(q, k, v, causal=True) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + peak = ref.abs().max().item() + strict_err = (strict - ref).abs().max().item() + print(f"\n[attention fp32-vs-tf32] strict_err={strict_err:.3g} peak={peak:.3g}") + assert strict_err <= 1.0e-3 * peak # fp32-tight, well under TF32 drift floor + + +# --------------------------------------------------------------------------- # +# Attention-specific correctness (closed-form, independent of the op's code) +# --------------------------------------------------------------------------- # +# With q=k=0 every score is 0, so softmax is uniform over the *visible* keys. +# Under causal masking query i sees keys 0..i, so out[i] == mean(v[0..i]). +def test_causal_uniform_attention_closed_form(): + """Causal masking: with uniform scores, out[i] == mean of keys 0..i.""" + B, H, S, D = 1, 1, 4, 3 + q = torch.zeros(B, H, S, D) + k = torch.zeros(B, H, S, D) + v = torch.arange(S * D, dtype=torch.float32).reshape(1, 1, S, D) # distinct rows + out = NativeAttentionOp().forward_fp32(q, k, v, causal=True) + expected = torch.stack([v[0, 0, : i + 1].mean(dim=0) for i in range(S)]) # cumulative mean + assert torch.allclose(out[0, 0], expected, atol=1e-6) + + +# Decode special case: a single query (Sq=1) with causal offset Skv-1+1=Skv masks +# nothing -> it must see all keys, i.e. equal the non-causal result. +def test_causal_decode_sees_all_keys(): + """Decode (Sq=1): causal masks nothing, so it equals the non-causal result.""" + op = NativeAttentionOp() + gen = torch.Generator().manual_seed(2) + q = torch.randn(2, _N_HEADS, 1, _HEAD_DIM, generator=gen) + k = torch.randn(2, _N_KV, 40, _HEAD_DIM, generator=gen) + v = torch.randn(2, _N_KV, 40, _HEAD_DIM, generator=gen) + assert torch.equal( + op.forward_fp32(q, k, v, causal=True), op.forward_fp32(q, k, v, causal=False) + ) + + +# GQA: 32 Q heads share 8 KV heads (g=4). Output keeps 32 heads, and the result +# matches an independent reference that expands KV with repeat_interleave. +def test_gqa_replication(): + """GQA: output keeps Hq=32 heads and matches the repeat_interleave reference.""" + q, k, v = _qkv(2, 8, 8, seed=3) + out = NativeAttentionOp().forward_fp32(q, k, v, causal=False) + assert out.shape == (2, _N_HEADS, 8, _HEAD_DIM) + assert out.shape[1] == 4 * k.shape[1] # 32 == g * 8 + assert torch.equal(out, _ref_softmax_attn(q, k, v, causal=False)) + + +def test_gqa_requires_divisible_heads(): + """Hq not divisible by Hkv is rejected (no valid GQA grouping).""" + gen = torch.Generator().manual_seed(31) + q = torch.randn(1, 6, 4, _HEAD_DIM, generator=gen) # 6 not divisible by 4 + k = torch.randn(1, 4, 4, _HEAD_DIM, generator=gen) + v = torch.randn(1, 4, 4, _HEAD_DIM, generator=gen) + with pytest.raises(ValueError, match="not divisible"): + NativeAttentionOp().forward_fp32(q, k, v, causal=False) + + +# scale defaults to 1/sqrt(head_dim); an explicit scale (incl. 0.0) is honored. +def test_scale_default_and_explicit(): + """Default scale is 1/sqrt(D); an explicit scale (incl. 0.0) is used verbatim.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=4) + assert torch.equal( + op.forward_fp32(q, k, v, causal=False), + op.forward_fp32(q, k, v, causal=False, scale=1.0 / math.sqrt(_HEAD_DIM)), + ) + # scale=0.0 -> all scores 0 -> uniform attention over all keys (mean of V). + out0 = op.forward_fp32(q, k, v, causal=False, scale=0.0) + assert torch.allclose( + out0, + v.float().repeat_interleave(4, dim=1).mean(dim=2, keepdim=True).expand_as(out0), + atol=1e-6, + ) + + +# key_padding_mask (True=valid): padded key columns get zero weight, so the +# result equals attention computed over only the valid keys. +# +# NOTE: padding changes the softmax reduction width (Skv=10 vs Skv=6). Even +# though masked positions contribute exp(-inf)=0 to the sum, the internal +# reduction order of torch.softmax over a size-10 row vs a size-6 row may +# differ (vectorisation boundaries, intermediate rounding of partial sums), +# so bitwise equality across different Skv is NOT guaranteed in IEEE 754. +# We assert near-equality (atol=2e-6) which validates the masking semantics +# without over-constraining the floating-point reduction path. The observed +# drift is ~1.3e-6 and is platform-sensitive, so the threshold carries headroom. +_PADDING_ATOL = 2.0e-6 + + +def test_key_padding_mask_excludes_padded_keys(): + """key_padding_mask: padded keys get zero weight (≈ attending over valid keys only). + + Not bitwise-equal because the softmax reduction width differs (Skv=10 vs 6); + see comment above for rationale. + """ + op = NativeAttentionOp() + gen = torch.Generator().manual_seed(5) + q = torch.randn(2, _N_HEADS, 6, _HEAD_DIM, generator=gen) + k_valid = torch.randn(2, _N_KV, 6, _HEAD_DIM, generator=gen) + v_valid = torch.randn(2, _N_KV, 6, _HEAD_DIM, generator=gen) + pad_k = torch.randn(2, _N_KV, 4, _HEAD_DIM, generator=gen) # 4 padding columns + pad_v = torch.randn(2, _N_KV, 4, _HEAD_DIM, generator=gen) + k = torch.cat([k_valid, pad_k], dim=2) + v = torch.cat([v_valid, pad_v], dim=2) + mask = torch.zeros(2, 10, dtype=torch.bool) + mask[:, :6] = True # first 6 valid, last 4 padding + + masked = op.forward_fp32(q, k, v, causal=False, key_padding_mask=mask) + valid_only = op.forward_fp32(q, k_valid, v_valid, causal=False) + + diff = (masked - valid_only).abs() + max_err = diff.max().item() + print(f"\n[padding mask] max_abs_err={max_err:.3g} (threshold={_PADDING_ATOL:.1g})") + assert torch.allclose( + masked, valid_only, atol=_PADDING_ATOL, rtol=0.0 + ), f"Padding-masked result diverges from valid-only by {max_err:.3g} > {_PADDING_ATOL}" + + +# A query whose every key is padded out has an all -inf row; naive softmax would +# emit NaN. The op defines such fully-masked rows as 0, keeping outputs and grads +# finite (NaN would poison both and break alignment against downstream kernels). +def test_fully_masked_query_returns_zero_not_nan(): + """All keys padded out -> the query yields 0 (not NaN), and grads stay finite.""" + gen = torch.Generator().manual_seed(9) + q = torch.randn(1, _N_HEADS, 4, _HEAD_DIM, generator=gen, requires_grad=True) + k = torch.randn(1, _N_KV, 4, _HEAD_DIM, generator=gen) + v = torch.randn(1, _N_KV, 4, _HEAD_DIM, generator=gen) + mask = torch.zeros(1, 4, dtype=torch.bool) # all False == every key is padding + + out = NativeAttentionOp().forward_fp32(q, k, v, causal=False, key_padding_mask=mask) + assert torch.isfinite(out).all() + assert torch.equal(out, torch.zeros_like(out)) + + # NaN would propagate through backward; assert the gradient is finite (zero here). + out.sum().backward() + assert torch.isfinite(q.grad).all() + + +# --------------------------------------------------------------------------- # +# Axis-B accuracy +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_B) +def test_dtype_path_accuracy(dtype: torch.dtype): + """Axis-B: the low-precision path drifts from fp32 by a bounded fraction of the output peak.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 16, 16, seed=2) + ref = op.forward_fp32(q, k, v, causal=True) # fp32 ground truth + cand = op.forward(q.to(dtype), k.to(dtype), v.to(dtype), causal=True) + assert cand.dtype == dtype + + err = (cand.float() - ref).abs() + peak = ref.abs().max() + max_abs, mean_abs = err.max().item(), err.mean().item() + print(f"\n[attention {dtype}] max_abs={max_abs:.4g} mean_abs={mean_abs:.4g} peak={peak:.4g}") + assert max_abs <= _DTYPE_REL_PEAK[dtype] * peak.item() + + +def test_output_shape(): + """Output shape is [B, Hq, Sq, D].""" + q, k, v = _qkv(3, 7, 7, seed=3) + out = NativeAttentionOp().forward(q, k, v, causal=True) + assert out.shape == (3, _N_HEADS, 7, _HEAD_DIM) + + +# --------------------------------------------------------------------------- # +# Axis-A batch invariance (bitwise, single-thread reduction order) +# --------------------------------------------------------------------------- # +# A sequence's attention output must not depend on how many other sequences +# share the batch. Compute on the full batch once, then slice -- never compute a +# slice on its own. Requires the pinned single-thread reduction order. +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_batch_invariance_slice(dtype: torch.dtype): + """Axis-A: a sequence's output is bitwise-independent of how many share the batch.""" + op = NativeAttentionOp() + q, k, v = _qkv(8, 16, 16, seed=5, dtype=dtype) + with _single_thread(): + full = op.forward(q, k, v, causal=True) + assert torch.equal(op.forward(q[:1], k[:1], v[:1], causal=True), full[:1]) + assert torch.equal(op.forward(q[3:5], k[3:5], v[3:5], causal=True), full[3:5]) + + +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_batch_invariance_chunked(dtype: torch.dtype): + """Axis-A (chunked): processing the batch in chunks and concatenating == one shot.""" + op = NativeAttentionOp() + q, k, v = _qkv(8, 16, 16, seed=6, dtype=dtype) + with _single_thread(): + full = op.forward(q, k, v, causal=True) + c1 = op.forward(q[:3], k[:3], v[:3], causal=True) + c2 = op.forward(q[3:], k[3:], v[3:], causal=True) + assert torch.equal(torch.cat([c1, c2], dim=0), full) + + +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_backward_batch_invariance_slice(dtype: torch.dtype): + """Axis-A (backward): q/k/v grads are bitwise-independent of batch size.""" + op = NativeAttentionOp() + q_full, k_full, v_full = _qkv(8, 16, 16, seed=10, dtype=dtype) + + q_full.requires_grad_(True) + k_full.requires_grad_(True) + v_full.requires_grad_(True) + + gen = torch.Generator().manual_seed(10) + dy = torch.randn(8, _N_HEADS, 16, _HEAD_DIM, generator=gen, dtype=dtype) + + with _single_thread(): + out_full = op.forward(q_full, k_full, v_full, causal=True) + out_full.backward(dy) + + q_slice = q_full[:1].detach().clone().requires_grad_(True) + k_slice = k_full[:1].detach().clone().requires_grad_(True) + v_slice = v_full[:1].detach().clone().requires_grad_(True) + + with _single_thread(): + out_slice = op.forward(q_slice, k_slice, v_slice, causal=True) + out_slice.backward(dy[:1]) + + assert torch.equal(q_slice.grad, q_full.grad[:1]) + assert torch.equal(k_slice.grad, k_full.grad[:1]) + assert torch.equal(v_slice.grad, v_full.grad[:1]) + + +# --------------------------------------------------------------------------- # +# Purity / gradient / registry +# --------------------------------------------------------------------------- # +def test_inputs_not_mutated(): + """Purity: no input tensor is mutated in place.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=7) + qc, kc, vc = q.clone(), k.clone(), v.clone() + mask = torch.ones(2, 8, dtype=torch.bool) + mc = mask.clone() + op.forward(q, k, v, causal=True, key_padding_mask=mask) + op.forward_fp32(q, k, v, causal=True, key_padding_mask=mask) + assert torch.equal(q, qc) and torch.equal(k, kc) and torch.equal(v, vc) + assert torch.equal(mask, mc) + + +def test_gradient_matches_reference(): + """fp32 autograd grads match autograd through the double-precision reference. + + isfinite only rules out NaN/Inf -- it can't tell a correct gradient from a + wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV + contractions) is the most error-prone in the stack. Backprop a *random* + cotangent (not .sum(), whose all-ones cotangent collapses the contraction) + and compare q/k/v grads against autograd through the independent + _ref_softmax_attn computed in float64 (TF32-immune high-precision golden). + """ + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=8) + q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True) + + out = op.forward_fp32(q, k, v, causal=True) + gen = torch.Generator().manual_seed(8) + dy = torch.randn(out.shape, generator=gen, dtype=out.dtype) # seeded for reproducibility + out.backward(dy) + + qd, kd, vd = (t.detach().double().requires_grad_(True) for t in (q, k, v)) + _ref_softmax_attn(qd, kd, vd, causal=True).backward(dy.double()) + for t, td in ((q, qd), (k, kd), (v, vd)): + assert t.grad is not None and t.grad.shape == t.shape + torch.testing.assert_close(t.grad, td.grad.float(), rtol=1e-4, atol=1e-4) + + +def test_registry_dispatches_native_attention_op(): + """The registry resolves "attention" to the ground-truth NativeAttentionOp.""" + assert isinstance(kernel_registry.get_op("attention"), NativeAttentionOp) + + +# --------------------------------------------------------------------------- # +# Qwen3-8B LARGE real-scale GPU smoke test +# --------------------------------------------------------------------------- # +# The scores tensor [B=8, Hq=32, Skv=4096, Skv=4096] is ~17 GB in fp32, so the +# LARGE load point is GPU-only and skips without enough memory. SMALL/MEDIUM at +# real head dims already run on CPU above; this validates real prefill scale. +def _enough_gpu_memory(num_bytes: int) -> bool: + if not torch.cuda.is_available(): + return False + try: + free, _ = torch.cuda.mem_get_info() + except RuntimeError: + return False + return free > num_bytes + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU") +def test_attention_qwen3_8b_large_real_shape(): + """GPU smoke at LARGE Qwen3-8B prefill (batch=8, seq=4096, 32/8/128).""" + # Runtime memory check (not a collection-time skipif): free memory at + # collection time is not representative on a shared GPU. Naive fp32 attention + # peaks at ~3x the scores tensor (scores + masked copy + softmax probs all + # live transiently), so budget 3x the ~17 GB scores -> ~50 GB peak. This makes + # LARGE an H100-class (H-series nightly) test, skipping on smaller GPUs. + scores_bytes = 8 * _N_HEADS * 4096 * 4096 * 4 # ~17 GB + if not _enough_gpu_memory(scores_bytes * 3): + pytest.skip("not enough free GPU memory for the ~50 GB fp32 LARGE attention peak") + device = torch.device("cuda") + op = NativeAttentionOp() + gen = torch.Generator(device=device).manual_seed(0) + q = torch.randn(8, _N_HEADS, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + k = torch.randn(8, _N_KV, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + v = torch.randn(8, _N_KV, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + out = op.forward_fp32(q, k, v, causal=True) + assert out.shape == (8, _N_HEADS, 4096, _HEAD_DIM) + assert torch.isfinite(out).all() + # Axis-A: compute on full batch, then slice (no per-slice recompute). + assert torch.equal(op.forward_fp32(q[:1], k[:1], v[:1], causal=True), out[:1])