From 444eb3b64dc1533cae6be3bee2ccd44ad2978338 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Wed, 24 Jun 2026 15:19:08 +0800 Subject: [PATCH 1/7] feat(ws1): add NativeAttentionOp pure-PyTorch standard-softmax reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WS1 ground-truth attention op for issue #108 (Qwen3-8B GQA attention): - NativeAttentionOp: out = softmax(Q Kᵀ * scale + masks) @ V, a hand-written naive softmax (NOT F.scaled_dot_product_attention / flash) so the reduction order is fixed for the batch-invariance contract. GQA 32/8 via repeat_interleave, causal offset Skv-Sq+1 (prefill + decode), key_padding_mask (True=valid), scale default 1/sqrt(128). Exposes the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); forward_fp32 disables TF32/autocast for a strict fp32 reference. Pure function, fp32 accumulation. - register PYTORCH_NATIVE_ATTENTION in OpBackend and the cuda/rocm/cpu priority maps under op_type "attention" (distinct from the production "attn" / PYTORCH_ATTN SDPA fallback) - tests/test_attention.py: forward_fp32 vs independent fp32 reference, closed-form causal/decode, GQA replication + divisibility guard, scale, key-padding, dtype-path accuracy (Axis-B), Axis-A batch invariance (slice + chunked + padding), purity, gradient flow, registry dispatch, GPU-only LARGE Qwen3-8B smoke - docs/operators/attention.md + nav/index wiring --- docs/.nav.yml | 1 + docs/operators/README.md | 1 + docs/operators/attention.md | 153 +++++++ .../ops/pytorch/attention/standard_attn.py | 182 ++++++++ rl_engine/kernels/registry.py | 9 + tests/test_attention.py | 391 ++++++++++++++++++ 6 files changed, 737 insertions(+) create mode 100644 docs/operators/attention.md create mode 100644 rl_engine/kernels/ops/pytorch/attention/standard_attn.py create mode 100644 tests/test_attention.py diff --git a/docs/.nav.yml b/docs/.nav.yml index e9ebaf0..8669e30 100644 --- a/docs/.nav.yml +++ 
b/docs/.nav.yml @@ -10,6 +10,7 @@ nav: - getting_started/faq.md - Operators: - operators/README.md + - operators/attention.md - operators/fused-logp.md - operators/linear-logp.md - operators/grpo-loss.md diff --git a/docs/operators/README.md b/docs/operators/README.md index c4eae60..4858d04 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -18,6 +18,7 @@ Every operator page should include: ## Current Pages +- [Standard Attention](attention.md) - [Fused LogP](fused-logp.md) - [Fused Linear LogP](linear-logp.md) - [GRPO Loss](grpo-loss.md) diff --git a/docs/operators/attention.md b/docs/operators/attention.md new file mode 100644 index 0000000..77feef0 --- /dev/null +++ b/docs/operators/attention.md @@ -0,0 +1,153 @@ +# Standard Softmax Attention + +The attention operator is the reduction core of the Qwen3/Llama transformer block. It is a +**WS1 ground-truth reference** (issue #108): a pure-PyTorch, fp32-accumulating definition of +the "correct answer" that downstream fused CUDA/Triton attention kernels (FlashAttention and +friends) are validated against. + +- **`NativeAttentionOp`**: `out = softmax(Q Kᵀ · scale + masks) @ V` — a hand-written naive + softmax with a **fixed reduction order**, deliberately **not** + `F.scaled_dot_product_attention` / flash / mem-efficient attention (whose reduction order + is unspecified and would break the batch-invariance contract). + +This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before* +the call (see the chain), so the `q`, `k` passed in are already normalized and rotated. 
+ +``` +q --\ +k ----softmax(QKᵀ/√d + mask)·V--> out +v --/ +``` + +## Entry Point +```python +from rl_engine.kernels.registry import kernel_registry + +attn = kernel_registry.get_op("attention") + +# Prefill: Sq == Skv ; Decode: Sq < Skv (one/few new queries against the full cache) +out = attn(q, k, v, causal=True) # [B, 32, Sq, 128] +out = attn(q, k, v, causal=True, scale=1.0 / 128 ** 0.5) # explicit scale +out = attn(q, k, v, causal=False, key_padding_mask=mask) # mask: [B, Skv] bool, True = keep +``` + +The op exposes the WS1 dual-path contract: + +- `forward(...)` — computes in the input dtype, returns the input dtype (Axis-B accuracy + candidate / dtype-behavior path). +- `forward_fp32(...)` — upcasts to fp32, accumulates in fp32, returns fp32 (the ground-truth + golden path). It disables autocast and TF32 so it stays a true fp32 reference regardless of + the caller's ambient precision context. + +> **Not the same as `"attn"`.** `kernel_registry.get_op("attn")` resolves to the production +> SDPA fallback (`PYTORCH_ATTN`); this ground-truth op is registered separately under +> `"attention"` (`PYTORCH_NATIVE_ATTENTION`). The two do not overlap. + +## Backends + +| Backend | Wrapper | Native symbol | Status | +| --- | --- | --- | --- | +| PyTorch fallback | `NativeAttentionOp` | None | fp32 ground-truth reference; CPU and any GPU. | +| CUDA / ROCm / Triton | — | — | Planned: downstream fused attention kernels validate against this reference. | + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `q` | `[B, Hq, Sq, D]` | float (fp16/bf16/fp32) | Qwen3-8B: `Hq=32`, `D=128`. | +| `k` | `[B, Hkv, Skv, D]` | float | Qwen3-8B: `Hkv=8` (GQA). `Hq` must be divisible by `Hkv`. | +| `v` | `[B, Hkv, Skv, D]` | float | Same head/seq layout as `k`. | +| `causal` | — | bool (kw, default `True`) | Upper-triangular mask at offset `Skv - Sq + 1`. | +| `scale` | — | float or `None` (kw) | `None` → `1/sqrt(D)` = `1/√128`. 
An explicit value (incl. `0.0`) is used verbatim. | +| `key_padding_mask` | `[B, Skv]` | bool or `None` (kw) | `True` = valid / keep, `False` = padding → that key column set to `-inf`. | +| output | `[B, Hq, Sq, D]` | `forward`: input dtype · `forward_fp32`: float32 | Heads precede seq (`[B, H, S, D]`). | + +**GQA** (`Hq=32`, `Hkv=8`, group `g=4`): each KV head is replicated `g` times with +`repeat_interleave(g, dim=1)` (not `repeat`), so query head `h` maps to KV head `h // g`. + +**Causal offset** `Skv - Sq + 1` anchors the queries to the end of the sequence, so a single +expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees +the whole cache). + +Pure function — no randomness, no in-place mutation; device and original dtype follow the +inputs. Masks are built on the inputs' device. + +## Dispatch Behavior + +`kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On +`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op +(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to the fp32 reference. When fused +attention kernels land, they are prepended to the priority list and the native op becomes the +fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is +a separate dispatch chain and is unaffected. 
+ +## Accuracy + +Reference semantics (`forward_fp32`, fp32 accumulation, TF32/autocast disabled): + +```python +qf, kf, vf = q.float(), k.float(), v.float() +if Hkv != Hq: # GQA: replicate KV, 32 Q / 8 KV, r = 4 + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) +scale = scale if scale is not None else 1.0 / math.sqrt(D) # D=128 → 1/√128 +scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] +if causal: # offset covers prefill + decode + m = torch.triu(torch.ones(Sq, Skv, dtype=torch.bool), diagonal=Skv - Sq + 1) + scores = scores.masked_fill(m, float("-inf")) +if key_padding_mask is not None: # True = keep ; False columns → -inf + scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) +probs = torch.softmax(scores, dim=-1) # subtracts per-row max internally +out = torch.matmul(probs, vf) # [B, Hq, Sq, D] +``` + +- **Ground truth**: `forward_fp32` always accumulates in and returns fp32, with TF32 and + autocast disabled so it is not silently downcast by the caller's ambient context. +- **Dtype path**: `forward` runs the same math in the input dtype, so low-precision reductions + over the key dimension drift from the fp32 reference — Axis-B accuracy therefore uses a + tolerance, not bitwise equality. +- **Axis A — batch invariance**: each query row reduces over the keys independently of how many + sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`) + across batch slicing, padding, **and chunked** (chunked-prefill) configurations. +- **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` + row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): + + | dtype | max_abs / peak | threshold (rel-peak) | + | --- | --- | --- | + | bfloat16 | ~0.56 % | 3 % | + | float16 | ~0.07 % | 0.5 % | + +## Performance Notes + +Reference operator — no fused kernel or benchmark yet. 
Downstream fused attention kernels carry +their own benchmarks and are measured against this reference for correctness. At the LARGE +Qwen3-8B load point (`B=8`, `Skv=4096`, `Hq=32`) the fp32 scores tensor alone is ~17 GB and the +naive path peaks at ~3× that, so the LARGE smoke test is GPU-only and skips without enough +memory. + +## Tests + +```bash +python -m pytest tests/test_attention.py -v +``` + +Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile +autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard, +scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch +invariance (slice + chunked + padding), input purity, gradient flow, registry dispatch, and a +GPU-only LARGE Qwen3-8B real-shape smoke test. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/attention/standard_attn.py` +- `rl_engine/kernels/registry.py` +- `tests/test_attention.py` + +## Known Limitations + +- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). +- `Hq` must be divisible by `Hkv` (raises `ValueError` otherwise). +- The naive path materializes the full `[B, Hq, Sq, Skv]` scores tensor — no query-chunking, + so the LARGE load point is memory-heavy and GPU-only. +- Covers softmax attention only; QK-Norm and RoPE are applied before the call. diff --git a/rl_engine/kernels/ops/pytorch/attention/standard_attn.py b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py new file mode 100644 index 0000000..8c3c94d --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import math +from contextlib import contextmanager, nullcontext +from typing import Optional + +import torch + + +class NativeAttentionOp: + """ + Pure PyTorch native standard-softmax attention reference. 
+ out = softmax(Q Kᵀ * scale + masks) @ V + + Hand-written naive softmax -- deliberately NOT + ``F.scaled_dot_product_attention`` / flash / mem-efficient attention, whose + reduction order is unspecified and would break the batch-invariance (Axis-A) + contract. This op defines the *correct answer* the fused kernels align to. + + Qwen3-8B shapes: q ``[B, 32, Sq, 128]``, k/v ``[B, 8, Skv, 128]`` (GQA group + g = 32/8 = 4), scale = 1/sqrt(head_dim) = 1/sqrt(128). Heads precede seq in + the layout. This op covers ONLY the softmax attention; QK-Norm and RoPE are + applied *before* the call (see the chain test) -- the q,k passed in are + already normalized and rotated. + + This is a reduction over the key dimension (Skv): the low-precision + ``forward`` path accumulates in the input dtype and therefore drifts from + the fp32 ``forward_fp32`` ground truth, so Axis-B accuracy uses a tolerance + (``torch.allclose``), not bitwise equality. Axis-A batch invariance still + holds bitwise within a single dtype (each query row reduces over the keys + independently of how many sequences share the batch). + + Masking conventions: + * causal=True -> upper-triangular -inf at diagonal Skv-Sq+1, valid for + both prefill (Sq==Skv) and decode (Sq<Skv, new queries see the whole cache). + * key_padding_mask: [B, Skv] bool, True=valid, False=padding -> padded + key columns set to -inf (matches reference_ops.py: True=keep, False=mask).
+ """ + + def __init__(self) -> None: + """No state; the op is a pure function over (q, k, v, ...).""" + + def __call__( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Alias for ``forward`` so the op is callable like a module.""" + return self.forward(q, k, v, causal=causal, scale=scale, key_padding_mask=key_padding_mask) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Canonical entry: attend in the input dtype, output the input dtype. + This is the dtype-behavior path used as the Axis-B accuracy candidate. + """ + return self._attention( + q, + k, + v, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + compute_dtype=q.dtype, + output_dtype=q.dtype, + ) + + def forward_fp32( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool = True, + scale: Optional[float] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Ground truth: upcast to fp32, accumulate in fp32, force fp32 output. + + The whole score->softmax->value path is wrapped to disable autocast and + TF32 so this stays a true fp32 reference regardless of the caller's + ambient precision context. 
+ """ + return self._attention( + q, + k, + v, + causal=causal, + scale=scale, + key_padding_mask=key_padding_mask, + compute_dtype=torch.float32, + output_dtype=torch.float32, + strict_fp32=True, + ) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + causal: bool, + scale: Optional[float], + key_padding_mask: Optional[torch.Tensor], + compute_dtype: torch.dtype, + output_dtype: torch.dtype, + strict_fp32: bool = False, + ) -> torch.Tensor: + """Core softmax attention: cast to ``compute_dtype``, score, mask, softmax, + weighted-sum over V, cast out. ``strict_fp32`` disables autocast/TF32 so the + fp32 reference is not silently downcast by the caller's ambient context. + """ + Hq, Sq, D = q.shape[1], q.shape[2], q.shape[3] + Hkv, Skv = k.shape[1], k.shape[2] + if Hq % Hkv != 0: + raise ValueError(f"Hq={Hq} not divisible by Hkv={Hkv} (GQA group)") + + ctx = NativeAttentionOp._strict_fp32_math(q.device.type) if strict_fp32 else nullcontext() + with ctx: + qf = q.to(compute_dtype) + kf = k.to(compute_dtype) + vf = v.to(compute_dtype) + + # GQA: replicate each KV head g=Hq//Hkv times (Qwen3: 32/8 -> 4). + # repeat_interleave (not repeat) keeps each KV head's copies adjacent + # so query head h maps to KV head h // g. + if Hkv != Hq: + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) + + # scale defaults to 1/sqrt(head_dim); `is not None` so an explicit 0.0 is kept. + scale = scale if scale is not None else (1.0 / math.sqrt(D)) + scores = torch.matmul(qf, kf.transpose(-1, -2)) * scale # [B, Hq, Sq, Skv] + + # Causal: offset Skv-Sq+1 covers prefill (Sq==Skv) and decode (Sq<Skv). + if causal: + causal_mask = torch.triu( + torch.ones(Sq, Skv, dtype=torch.bool, device=q.device), + diagonal=Skv - Sq + 1, + ) + scores = scores.masked_fill(causal_mask, float("-inf")) + + # key_padding_mask: True = valid; False (padding) columns -> -inf.
+ if key_padding_mask is not None: + pad = ~key_padding_mask + scores = scores.masked_fill(pad[:, None, None, :], float("-inf")) + + probs = torch.softmax(scores, dim=-1) # subtracts row max internally + out = torch.matmul(probs, vf) # [B, Hq, Sq, D] + return out.to(output_dtype) + + @staticmethod + @contextmanager + def _strict_fp32_math(device_type: str): + """Disable autocast and TF32 for a true fp32 path, restoring state after.""" + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + with torch.autocast(device_type=device_type, enabled=False): + yield + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 6780157..caa5a45 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -54,6 +54,12 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): PYTORCH_ATTN = "rl_engine.kernels.ops.pytorch.attention.NativeAttentionOp" PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp" + # WS1 pure-PyTorch ground-truth attention reference (hand-written fp32 softmax). + # Distinct from PYTORCH_ATTN above, which is the production SDPA fallback. 
+ PYTORCH_NATIVE_ATTENTION = ( + "rl_engine.kernels.ops.pytorch.attention.standard_attn.NativeAttentionOp" + ) + class KernelRegistry: """ @@ -86,6 +92,7 @@ def __init__(self): OpBackend.PYTORCH_NATIVE, ], "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_ATTN], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], @@ -98,6 +105,7 @@ def __init__(self): OpBackend.PYTORCH_ATTN, OpBackend.TRITON_GENERIC, ], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], @@ -105,6 +113,7 @@ def __init__(self): "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_ATTN], + "attention": [OpBackend.PYTORCH_NATIVE_ATTENTION], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], diff --git a/tests/test_attention.py b/tests/test_attention.py new file mode 100644 index 0000000..7c46684 --- /dev/null +++ b/tests/test_attention.py @@ -0,0 +1,391 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Tests for NativeAttentionOp (ISSUE #108 WS1 ground-truth baseline). + +Standard-softmax attention: out = softmax(Q Kᵀ * scale + masks) @ V, hand-written +(NOT F.scaled_dot_product_attention) so the reduction order is fixed. Like +lm_head this is a *reduction* (over the key dim Skv), so: + + * Axis-B (accuracy): the low-precision ``forward`` path accumulates in the + input dtype and drifts from the fp32 ``forward_fp32`` ground truth. 
It is + checked with a tolerance relative to the output peak magnitude, not bitwise + -- attention outputs are convex combinations of V rows, so many entries sit + near zero while the accumulated error tracks the reduction length. + * Axis-A (batch invariance): bitwise within a single dtype, but only once the + CPU reduction order is pinned. Multi-threaded CPU GEMM splits the matmul + reduction by the batch dimension, which silently breaks bitwise batch + invariance. ``_single_thread`` fixes the reduction order; it is the local + stand-in for the planned testing/determinism.py::deterministic_context. + +This op covers ONLY the softmax attention; QK-Norm and RoPE are applied before +the call (see the chain test) -- the q,k here are plain synthetic tensors. +""" + +import contextlib +import math + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.attention.standard_attn import NativeAttentionOp +from rl_engine.kernels.registry import kernel_registry + +# Qwen3-8B attention dims (synthetic tensors, no checkpoint). Unlike embedding / +# lm_head (whose multi-GB weight forces shrinking), attention's cost is the +# scores tensor [B, Hq, Sq, Skv], so the *real* head dims are cheap at a SMALL +# (batch, seq) load point and kept real here. Only LARGE (8, 4096) is GPU-only. +_N_HEADS = 32 # Q heads +_N_KV = 8 # KV heads; GQA group g = 32 / 8 = 4 +_HEAD_DIM = 128 # 32 * 128 == 4096 == hidden + +# Axis-B: max abs error as a fraction of the output peak magnitude. Calibrated +# from measured SMALL drift (bf16 ~1% of peak, fp16 ~0.1%) with headroom. 
+_DTYPE_REL_PEAK = {torch.bfloat16: 3.0e-2, torch.float16: 5.0e-3} + + +def _cpu_fp16_matmul_supported() -> bool: + """Probe whether this CPU backend implements float16 matmul.""" + try: + _ = torch.randn(2, 2, dtype=torch.float16) @ torch.randn(2, 2, dtype=torch.float16) + return True + except RuntimeError: + return False + + +# CPU half-precision matmul is backend/ISA-dependent and may be unimplemented on +# some runners -- gate the fp16 axis so a missing kernel skips rather than fails. +_FP16_IF_CPU_MATMUL_SUPPORTED = pytest.param( + torch.float16, + marks=pytest.mark.skipif( + not _cpu_fp16_matmul_supported(), + reason="CPU float16 matmul unsupported on this backend", + ), +) +_DTYPES_AXIS_B = (torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) +_DTYPES_AXIS_A = (torch.float32, torch.bfloat16, _FP16_IF_CPU_MATMUL_SUPPORTED) + + +@contextlib.contextmanager +def _single_thread(): + """Pin CPU GEMM to one thread so the matmul reduction order is batch-independent.""" + prev = torch.get_num_threads() + torch.set_num_threads(1) + try: + yield + finally: + torch.set_num_threads(prev) + + +# Shared helpers -- fixed-seed Generator for determinism / reproducibility. +def _qkv(batch, sq, skv, *, seed, dtype=torch.float32, n_heads=_N_HEADS, n_kv=_N_KV, d=_HEAD_DIM): + """Fixed-seed random q [B,Hq,Sq,D], k/v [B,Hkv,Skv,D] for reproducibility.""" + gen = torch.Generator().manual_seed(seed) + q = torch.randn(batch, n_heads, sq, d, generator=gen, dtype=dtype) + k = torch.randn(batch, n_kv, skv, d, generator=gen, dtype=dtype) + v = torch.randn(batch, n_kv, skv, d, generator=gen, dtype=dtype) + return q, k, v + + +def _ref_softmax_attn(q, k, v, *, causal, scale=None, key_padding_mask=None): + """Independent naive-softmax reference mirroring the contract (GQA, masks). + + Dtype-preserving: pass fp32 tensors for a bitwise fp32 check, or .double() + tensors for a TF32-immune high-precision reference. 
+ """ + qf, kf, vf = q, k, v + Hq, Sq, D = qf.shape[1], qf.shape[2], qf.shape[3] + Hkv, Skv = kf.shape[1], kf.shape[2] + if Hkv != Hq: + r = Hq // Hkv + kf = kf.repeat_interleave(r, dim=1) + vf = vf.repeat_interleave(r, dim=1) + s = torch.matmul(qf, kf.transpose(-1, -2)) * ( + scale if scale is not None else 1.0 / math.sqrt(D) + ) + if causal: + m = torch.triu( + torch.ones(Sq, Skv, dtype=torch.bool, device=qf.device), diagonal=Skv - Sq + 1 + ) + s = s.masked_fill(m, float("-inf")) + if key_padding_mask is not None: + s = s.masked_fill(~key_padding_mask[:, None, None, :], float("-inf")) + return torch.softmax(s, dim=-1) @ vf + + +# --------------------------------------------------------------------------- # +# fp32 ground-truth correctness +# --------------------------------------------------------------------------- # +# forward_fp32 == the independent naive fp32 reference, bitwise. Both fix the +# same reduction order, so this validates the op's wiring (transpose dims, scale, +# masks) exactly. TF32 is pinned off so the fp32 forward path matches too. 
+def test_forward_fp32_matches_independent_reference(): + """forward_fp32 (and the fp32 forward path, TF32 off) is bitwise-equal to a + naive fp32 reference.""" + q, k, v = _qkv(2, 16, 16, seed=1) # Qwen3 32/8/128, SMALL prefill + ref = _ref_softmax_attn(q, k, v, causal=True) + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = False + try: + assert torch.equal(NativeAttentionOp().forward_fp32(q, k, v, causal=True), ref) + assert torch.equal(NativeAttentionOp().forward(q, k, v, causal=True), ref) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + +def test_forward_fp32_ignores_ambient_autocast_and_restores_tf32(): + """forward_fp32 is a strict fp32 reference under ambient autocast/TF32 settings.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=11) + ref = _ref_softmax_attn(q, k, v, causal=True) + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True + try: + with torch.autocast(device_type="cpu", dtype=torch.bfloat16): + out = op.forward_fp32(q, k, v, causal=True) + assert out.dtype == torch.float32 + assert torch.equal(out, ref) + assert torch.backends.cuda.matmul.allow_tf32 is True # restored + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU to exercise TF32") +def test_forward_fp32_disables_tf32_on_gpu(): + """On a TF32-enabled GPU, forward_fp32 stays true fp32, no worse than a TF32 path.""" + device = torch.device("cuda") + gen = torch.Generator(device=device).manual_seed(21) + q = torch.randn(2, _N_HEADS, 64, _HEAD_DIM, generator=gen, device=device) + k = torch.randn(2, _N_KV, 64, _HEAD_DIM, generator=gen, device=device) + v = torch.randn(2, _N_KV, 64, _HEAD_DIM, generator=gen, device=device) + ref = _ref_softmax_attn(q.double(), k.double(), v.double(), causal=True).float() + + prev_tf32 = torch.backends.cuda.matmul.allow_tf32 + 
torch.backends.cuda.matmul.allow_tf32 = True # hostile ambient setting + try: + strict = NativeAttentionOp().forward_fp32(q, k, v, causal=True) + finally: + torch.backends.cuda.matmul.allow_tf32 = prev_tf32 + + peak = ref.abs().max().item() + strict_err = (strict - ref).abs().max().item() + print(f"\n[attention fp32-vs-tf32] strict_err={strict_err:.3g} peak={peak:.3g}") + assert strict_err <= 1.0e-3 * peak # fp32-tight, well under TF32 drift floor + + +# --------------------------------------------------------------------------- # +# Attention-specific correctness (closed-form, independent of the op's code) +# --------------------------------------------------------------------------- # +# With q=k=0 every score is 0, so softmax is uniform over the *visible* keys. +# Under causal masking query i sees keys 0..i, so out[i] == mean(v[0..i]). +def test_causal_uniform_attention_closed_form(): + """Causal masking: with uniform scores, out[i] == mean of keys 0..i.""" + B, H, S, D = 1, 1, 4, 3 + q = torch.zeros(B, H, S, D) + k = torch.zeros(B, H, S, D) + v = torch.arange(S * D, dtype=torch.float32).reshape(1, 1, S, D) # distinct rows + out = NativeAttentionOp().forward_fp32(q, k, v, causal=True) + expected = torch.stack([v[0, 0, : i + 1].mean(dim=0) for i in range(S)]) # cumulative mean + assert torch.allclose(out[0, 0], expected, atol=1e-6) + + +# Decode special case: a single query (Sq=1) with causal offset Skv-1+1=Skv masks +# nothing -> it must see all keys, i.e. equal the non-causal result. 
+def test_causal_decode_sees_all_keys(): + """Decode (Sq=1): causal masks nothing, so it equals the non-causal result.""" + op = NativeAttentionOp() + gen = torch.Generator().manual_seed(2) + q = torch.randn(2, _N_HEADS, 1, _HEAD_DIM, generator=gen) + k = torch.randn(2, _N_KV, 40, _HEAD_DIM, generator=gen) + v = torch.randn(2, _N_KV, 40, _HEAD_DIM, generator=gen) + assert torch.equal( + op.forward_fp32(q, k, v, causal=True), op.forward_fp32(q, k, v, causal=False) + ) + + +# GQA: 32 Q heads share 8 KV heads (g=4). Output keeps 32 heads, and the result +# matches an independent reference that expands KV with repeat_interleave. +def test_gqa_replication(): + """GQA: output keeps Hq=32 heads and matches the repeat_interleave reference.""" + q, k, v = _qkv(2, 8, 8, seed=3) + out = NativeAttentionOp().forward_fp32(q, k, v, causal=False) + assert out.shape == (2, _N_HEADS, 8, _HEAD_DIM) + assert out.shape[1] == 4 * k.shape[1] # 32 == g * 8 + assert torch.equal(out, _ref_softmax_attn(q, k, v, causal=False)) + + +def test_gqa_requires_divisible_heads(): + """Hq not divisible by Hkv is rejected (no valid GQA grouping).""" + gen = torch.Generator().manual_seed(31) + q = torch.randn(1, 6, 4, _HEAD_DIM, generator=gen) # 6 not divisible by 4 + k = torch.randn(1, 4, 4, _HEAD_DIM, generator=gen) + v = torch.randn(1, 4, 4, _HEAD_DIM, generator=gen) + with pytest.raises(ValueError, match="not divisible"): + NativeAttentionOp().forward_fp32(q, k, v, causal=False) + + +# scale defaults to 1/sqrt(head_dim); an explicit scale (incl. 0.0) is honored. +def test_scale_default_and_explicit(): + """Default scale is 1/sqrt(D); an explicit scale (incl. 0.0) is used verbatim.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=4) + assert torch.equal( + op.forward_fp32(q, k, v, causal=False), + op.forward_fp32(q, k, v, causal=False, scale=1.0 / math.sqrt(_HEAD_DIM)), + ) + # scale=0.0 -> all scores 0 -> uniform attention over all keys (mean of V). 
+ out0 = op.forward_fp32(q, k, v, causal=False, scale=0.0) + assert torch.allclose( + out0, + v.float().repeat_interleave(4, dim=1).mean(dim=2, keepdim=True).expand_as(out0), + atol=1e-6, + ) + + +# key_padding_mask (True=valid): padded key columns get zero weight, so the +# result equals attention computed over only the valid keys. +def test_key_padding_mask_excludes_padded_keys(): + """key_padding_mask: padded keys get zero weight (== attending over valid keys only).""" + op = NativeAttentionOp() + gen = torch.Generator().manual_seed(5) + q = torch.randn(2, _N_HEADS, 6, _HEAD_DIM, generator=gen) + k_valid = torch.randn(2, _N_KV, 6, _HEAD_DIM, generator=gen) + v_valid = torch.randn(2, _N_KV, 6, _HEAD_DIM, generator=gen) + pad_k = torch.randn(2, _N_KV, 4, _HEAD_DIM, generator=gen) # 4 padding columns + pad_v = torch.randn(2, _N_KV, 4, _HEAD_DIM, generator=gen) + k = torch.cat([k_valid, pad_k], dim=2) + v = torch.cat([v_valid, pad_v], dim=2) + mask = torch.zeros(2, 10, dtype=torch.bool) + mask[:, :6] = True # first 6 valid, last 4 padding + + masked = op.forward_fp32(q, k, v, causal=False, key_padding_mask=mask) + valid_only = op.forward_fp32(q, k_valid, v_valid, causal=False) + assert torch.equal(masked, valid_only) + + +# --------------------------------------------------------------------------- # +# Axis-B accuracy +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_B) +def test_dtype_path_accuracy(dtype: torch.dtype): + """Axis-B: the low-precision path drifts from fp32 by a bounded fraction of the output peak.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 16, 16, seed=2) + ref = op.forward_fp32(q, k, v, causal=True) # fp32 ground truth + cand = op.forward(q.to(dtype), k.to(dtype), v.to(dtype), causal=True) + assert cand.dtype == dtype + + err = (cand.float() - ref).abs() + peak = ref.abs().max() + max_abs, mean_abs = err.max().item(), err.mean().item() + print(f"\n[attention 
{dtype}] max_abs={max_abs:.4g} mean_abs={mean_abs:.4g} peak={peak:.4g}") + assert max_abs <= _DTYPE_REL_PEAK[dtype] * peak.item() + + +def test_output_shape(): + """Output shape is [B, Hq, Sq, D].""" + q, k, v = _qkv(3, 7, 7, seed=3) + out = NativeAttentionOp().forward(q, k, v, causal=True) + assert out.shape == (3, _N_HEADS, 7, _HEAD_DIM) + + +# --------------------------------------------------------------------------- # +# Axis-A batch invariance (bitwise, single-thread reduction order) +# --------------------------------------------------------------------------- # +# A sequence's attention output must not depend on how many other sequences +# share the batch. Compute on the full batch once, then slice -- never compute a +# slice on its own. Requires the pinned single-thread reduction order. +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_batch_invariance_slice(dtype: torch.dtype): + """Axis-A: a sequence's output is bitwise-independent of how many share the batch.""" + op = NativeAttentionOp() + q, k, v = _qkv(8, 16, 16, seed=5, dtype=dtype) + with _single_thread(): + full = op.forward(q, k, v, causal=True) + assert torch.equal(op.forward(q[:1], k[:1], v[:1], causal=True), full[:1]) + assert torch.equal(op.forward(q[3:5], k[3:5], v[3:5], causal=True), full[3:5]) + + +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_batch_invariance_chunked(dtype: torch.dtype): + """Axis-A (chunked): processing the batch in chunks and concatenating == one shot.""" + op = NativeAttentionOp() + q, k, v = _qkv(8, 16, 16, seed=6, dtype=dtype) + with _single_thread(): + full = op.forward(q, k, v, causal=True) + c1 = op.forward(q[:3], k[:3], v[:3], causal=True) + c2 = op.forward(q[3:], k[3:], v[3:], causal=True) + assert torch.equal(torch.cat([c1, c2], dim=0), full) + + +# --------------------------------------------------------------------------- # +# Purity / gradient / registry +# --------------------------------------------------------------------------- # 
+def test_inputs_not_mutated(): + """Purity: no input tensor is mutated in place.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=7) + qc, kc, vc = q.clone(), k.clone(), v.clone() + mask = torch.ones(2, 8, dtype=torch.bool) + mc = mask.clone() + op.forward(q, k, v, causal=True, key_padding_mask=mask) + op.forward_fp32(q, k, v, causal=True, key_padding_mask=mask) + assert torch.equal(q, qc) and torch.equal(k, kc) and torch.equal(v, vc) + assert torch.equal(mask, mc) + + +def test_gradient_flows(): + """fp32 autograd (the backward golden source) yields finite grads for q, k, v.""" + op = NativeAttentionOp() + q, k, v = _qkv(2, 8, 8, seed=8) + q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True) + op.forward_fp32(q, k, v, causal=True).sum().backward() + for t in (q, k, v): + assert t.grad is not None and t.grad.shape == t.shape + assert torch.isfinite(t.grad).all() + + +def test_registry_dispatches_native_attention_op(): + """The registry resolves "attention" to the ground-truth NativeAttentionOp.""" + assert isinstance(kernel_registry.get_op("attention"), NativeAttentionOp) + + +# --------------------------------------------------------------------------- # +# Qwen3-8B LARGE real-scale GPU smoke test +# --------------------------------------------------------------------------- # +# The scores tensor [B=8, Hq=32, Skv=4096, Skv=4096] is ~17 GB in fp32, so the +# LARGE load point is GPU-only and skips without enough memory. SMALL/MEDIUM at +# real head dims already run on CPU above; this validates real prefill scale. 
+def _enough_gpu_memory(num_bytes: int) -> bool: + if not torch.cuda.is_available(): + return False + try: + free, _ = torch.cuda.mem_get_info() + except RuntimeError: + return False + return free > int(num_bytes * 1.5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU") +def test_attention_qwen3_8b_large_real_shape(): + """GPU smoke at LARGE Qwen3-8B prefill (batch=8, seq=4096, 32/8/128).""" + # Runtime memory check (not a collection-time skipif): free memory at + # collection time is not representative on a shared GPU. Naive fp32 attention + # peaks at ~3x the scores tensor (scores + masked copy + softmax probs all + # live transiently), so budget 3x the ~17 GB scores -> ~50 GB peak. This makes + # LARGE an H100-class (H-series nightly) test, skipping on smaller GPUs. + scores_bytes = 8 * _N_HEADS * 4096 * 4096 * 4 # ~17 GB + if not _enough_gpu_memory(scores_bytes * 3): + pytest.skip("not enough free GPU memory for the ~50 GB fp32 LARGE attention peak") + device = torch.device("cuda") + op = NativeAttentionOp() + gen = torch.Generator(device=device).manual_seed(0) + q = torch.randn(8, _N_HEADS, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + k = torch.randn(8, _N_KV, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + v = torch.randn(8, _N_KV, 4096, _HEAD_DIM, generator=gen, dtype=torch.float32, device=device) + out = op.forward_fp32(q, k, v, causal=True) + assert out.shape == (8, _N_HEADS, 4096, _HEAD_DIM) + assert torch.isfinite(out).all() + # Axis-A: compute on full batch, then slice (no per-slice recompute). 
+ assert torch.equal(op.forward_fp32(q[:1], k[:1], v[:1], causal=True), out[:1]) From 256ec1d2bbfcf85b7aa521486852d3b4a7a9f426 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Wed, 24 Jun 2026 15:58:02 +0800 Subject: [PATCH 2/7] fix(ws1): address CodeRabbit review on attention op - standard_attn: define fully key-padding-masked query rows as 0 (was NaN); guarded to the padding branch so the no-pad path is unchanged, row-independent so Axis-A holds; add test_fully_masked_query_returns_zero_not_nan - test: drop the double 1.5x margin in _enough_gpu_memory (LARGE skip now ~50 GB as documented, no longer over-skips 80 GB GPUs) - docs/attention.md: add text lang to the diagram fence (MD040); clarify that dispatch uses forward() input-dtype path, forward_fp32() is the explicit fp32 path --- docs/operators/attention.md | 9 ++++---- .../ops/pytorch/attention/standard_attn.py | 9 ++++++++ tests/test_attention.py | 22 ++++++++++++++++++- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/docs/operators/attention.md b/docs/operators/attention.md index 77feef0..bc0d8fe 100644 --- a/docs/operators/attention.md +++ b/docs/operators/attention.md @@ -13,7 +13,7 @@ friends) are validated against. This op covers **only** the softmax attention. Qwen3's QK-Norm and RoPE are applied *before* the call (see the chain), so the `q`, `k` passed in are already normalized and rotated. -``` +```text q --\ k ----softmax(QKᵀ/√d + mask)·V--> out v --/ @@ -76,9 +76,10 @@ inputs. Masks are built on the inputs' device. `kernel_registry.get_op("attention")` resolves through the `OpBackend` priority map. On `cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op -(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to the fp32 reference. When fused -attention kernels land, they are prepended to the priority list and the native op becomes the -fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) 
is +(`PYTORCH_NATIVE_ATTENTION`), so every device dispatches to this op. Calling it (`__call__` -> +`forward(...)`) computes in the input dtype; `forward_fp32(...)` is the explicit fp32 golden +path. When fused attention kernels land, they are prepended to the priority list and the native +op becomes the fallback. The production `"attn"` op_type (SDPA-based `PYTORCH_ATTN`, FlashAttention, etc.) is a separate dispatch chain and is unaffected. ## Accuracy diff --git a/rl_engine/kernels/ops/pytorch/attention/standard_attn.py b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py index 8c3c94d..804efb6 100644 --- a/rl_engine/kernels/ops/pytorch/attention/standard_attn.py +++ b/rl_engine/kernels/ops/pytorch/attention/standard_attn.py @@ -166,6 +166,15 @@ def _attention( scores = scores.masked_fill(pad[:, None, None, :], float("-inf")) probs = torch.softmax(scores, dim=-1) # subtracts row max internally + if key_padding_mask is not None: + # A query whose every valid key is padded out has an all -inf row; + # softmax would emit NaN. Define such fully-masked rows as 0 (no key + # contributes), keeping outputs/grads finite. Only key_padding_mask can + # fully mask a row -- causal alone always leaves a query its own key -- + # so this guard lives in the padding branch and the no-pad path is + # untouched. Row-independent, so Axis-A batch invariance still holds. + all_masked = ~torch.isfinite(scores).any(dim=-1, keepdim=True) + probs = torch.where(all_masked, torch.zeros_like(probs), probs) out = torch.matmul(probs, vf) # [B, Hq, Sq, D] return out.to(output_dtype) diff --git a/tests/test_attention.py b/tests/test_attention.py index 7c46684..17ff290 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -264,6 +264,26 @@ def test_key_padding_mask_excludes_padded_keys(): assert torch.equal(masked, valid_only) +# A query whose every key is padded out has an all -inf row; naive softmax would +# emit NaN. 
The op defines such fully-masked rows as 0, keeping outputs and grads +# finite (NaN would poison both and break alignment against downstream kernels). +def test_fully_masked_query_returns_zero_not_nan(): + """All keys padded out -> the query yields 0 (not NaN), and grads stay finite.""" + gen = torch.Generator().manual_seed(9) + q = torch.randn(1, _N_HEADS, 4, _HEAD_DIM, generator=gen, requires_grad=True) + k = torch.randn(1, _N_KV, 4, _HEAD_DIM, generator=gen) + v = torch.randn(1, _N_KV, 4, _HEAD_DIM, generator=gen) + mask = torch.zeros(1, 4, dtype=torch.bool) # all False == every key is padding + + out = NativeAttentionOp().forward_fp32(q, k, v, causal=False, key_padding_mask=mask) + assert torch.isfinite(out).all() + assert torch.equal(out, torch.zeros_like(out)) + + # NaN would propagate through backward; assert the gradient is finite (zero here). + out.sum().backward() + assert torch.isfinite(q.grad).all() + + # --------------------------------------------------------------------------- # # Axis-B accuracy # --------------------------------------------------------------------------- # @@ -364,7 +384,7 @@ def _enough_gpu_memory(num_bytes: int) -> bool: free, _ = torch.cuda.mem_get_info() except RuntimeError: return False - return free > int(num_bytes * 1.5) + return free > num_bytes @pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a CUDA GPU") From 73a273b49cf105220703503d21a1e0ba8e9ce10f Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Thu, 25 Jun 2026 15:13:44 +0800 Subject: [PATCH 3/7] fix(ws1): run attention tests in CI, relax padding bitwise claim to atol=1e-6 --- .github/workflows/ci.yml | 4 ++++ docs/operators/attention.md | 5 ++++- tests/test_attention.py | 25 +++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 755f426..5c2a2a0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,6 +68,10 @@ jobs: python -m pytest 
rl_engine/tests/test_dispatch.py -v PYTEST_DISABLE_PLUGIN_AUTOLOAD=1 python -m pytest tests/test_attention_correctness.py -q -rs + - name: Run Attention Ground-Truth Tests (CPU-safe) + run: | + python -m pytest tests/test_attention.py -v -k "not large and not gpu" + docs: runs-on: ubuntu-latest steps: diff --git a/docs/operators/attention.md b/docs/operators/attention.md index bc0d8fe..e4c3c13 100644 --- a/docs/operators/attention.md +++ b/docs/operators/attention.md @@ -110,7 +110,10 @@ out = torch.matmul(probs, vf) # [B, Hq, Sq, D] tolerance, not bitwise equality. - **Axis A — batch invariance**: each query row reduces over the keys independently of how many sequences share the batch, so a row's output is bitwise-identical (`torch.equal`, `atol=0`) - across batch slicing, padding, **and chunked** (chunked-prefill) configurations. + across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax + reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width + (e.g. Skv=10 vs 6), so the masked result only matches the valid-only result up to a small + tolerance (`atol=1e-6`), not bitwise, in IEEE 754. - **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): diff --git a/tests/test_attention.py b/tests/test_attention.py index 17ff290..e985e5f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -245,8 +245,23 @@ def test_scale_default_and_explicit(): # key_padding_mask (True=valid): padded key columns get zero weight, so the # result equals attention computed over only the valid keys. +# +# NOTE: padding changes the softmax reduction width (Skv=10 vs Skv=6). 
Even +# though masked positions contribute exp(-inf)=0 to the sum, the internal +# reduction order of torch.softmax over a size-10 row vs a size-6 row may +# differ (vectorisation boundaries, intermediate rounding of partial sums), +# so bitwise equality across different Skv is NOT guaranteed in IEEE 754. +# We assert near-equality (atol=1e-6) which validates the masking semantics +# without over-constraining the floating-point reduction path. +_PADDING_ATOL = 1.0e-6 + + def test_key_padding_mask_excludes_padded_keys(): - """key_padding_mask: padded keys get zero weight (== attending over valid keys only).""" + """key_padding_mask: padded keys get zero weight (≈ attending over valid keys only). + + Not bitwise-equal because the softmax reduction width differs (Skv=10 vs 6); + see comment above for rationale. + """ op = NativeAttentionOp() gen = torch.Generator().manual_seed(5) q = torch.randn(2, _N_HEADS, 6, _HEAD_DIM, generator=gen) @@ -261,7 +276,13 @@ def test_key_padding_mask_excludes_padded_keys(): masked = op.forward_fp32(q, k, v, causal=False, key_padding_mask=mask) valid_only = op.forward_fp32(q, k_valid, v_valid, causal=False) - assert torch.equal(masked, valid_only) + + diff = (masked - valid_only).abs() + max_err = diff.max().item() + print(f"\n[padding mask] max_abs_err={max_err:.3g} (threshold={_PADDING_ATOL:.1g})") + assert torch.allclose(masked, valid_only, atol=_PADDING_ATOL, rtol=0.0), ( + f"Padding-masked result diverges from valid-only by {max_err:.3g} > {_PADDING_ATOL}" + ) # A query whose every key is padded out has an all -inf row; naive softmax would From 1e3a99087b00bc257def924e513edbb41ffa7103 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sat, 27 Jun 2026 15:59:36 +0800 Subject: [PATCH 4/7] fix(ws1): bump padding atol to 2e-6 and clarify Axis-A doc per review key_padding_mask drift over differing reduction widths (Skv=10 vs 6) is ~1.3e-6 and platform-sensitive; atol=1e-6 failed locally for the reviewer. 
Bump the threshold to 2e-6 for headroom, and update the test-coverage doc line so padding reads as near-equality, not part of the bitwise Axis-A claim. --- docs/operators/attention.md | 3 ++- tests/test_attention.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/operators/attention.md b/docs/operators/attention.md index e4c3c13..afe5e14 100644 --- a/docs/operators/attention.md +++ b/docs/operators/attention.md @@ -139,7 +139,8 @@ python -m pytest tests/test_attention.py -v Covers: `forward_fp32` vs an independent fp32 reference (bitwise), strict-fp32 under hostile autocast/TF32, closed-form causal/decode checks, GQA replication and the divisibility guard, scale defaults, key-padding masking, dtype-path accuracy (Axis-B), output shape, Axis-A batch -invariance (slice + chunked + padding), input purity, gradient flow, registry dispatch, and a +invariance (slice + chunked, bitwise; padding is near-equality only, see below), input purity, +gradient flow, registry dispatch, and a GPU-only LARGE Qwen3-8B real-shape smoke test. ## Implementation Files diff --git a/tests/test_attention.py b/tests/test_attention.py index e985e5f..34e0d90 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -251,9 +251,10 @@ def test_scale_default_and_explicit(): # reduction order of torch.softmax over a size-10 row vs a size-6 row may # differ (vectorisation boundaries, intermediate rounding of partial sums), # so bitwise equality across different Skv is NOT guaranteed in IEEE 754. -# We assert near-equality (atol=1e-6) which validates the masking semantics -# without over-constraining the floating-point reduction path. -_PADDING_ATOL = 1.0e-6 +# We assert near-equality (atol=2e-6) which validates the masking semantics +# without over-constraining the floating-point reduction path. The observed +# drift is ~1.3e-6 and is platform-sensitive, so the threshold carries headroom. 
+_PADDING_ATOL = 2.0e-6 def test_key_padding_mask_excludes_padded_keys(): From 674e83aee461f73819a43bdee972d69fd804b516 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sun, 28 Jun 2026 16:33:18 +0800 Subject: [PATCH 5/7] =?UTF-8?q?test(ws1):=20address=20PR=20#188=20review?= =?UTF-8?q?=20=E2=80=94=20gradient=20vs=20fp64=20reference=20(KJLdefeated)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace isfinite-only grad check with autograd through the independent double-precision _ref_softmax_attn under a random (seeded) cotangent. isfinite can't tell a correct gradient from a wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV contractions) is the most error-prone in the stack; .sum()'s all-ones cotangent would also collapse the contraction. --- tests/test_attention.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 34e0d90..b9eb598 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -281,9 +281,9 @@ def test_key_padding_mask_excludes_padded_keys(): diff = (masked - valid_only).abs() max_err = diff.max().item() print(f"\n[padding mask] max_abs_err={max_err:.3g} (threshold={_PADDING_ATOL:.1g})") - assert torch.allclose(masked, valid_only, atol=_PADDING_ATOL, rtol=0.0), ( - f"Padding-masked result diverges from valid-only by {max_err:.3g} > {_PADDING_ATOL}" - ) + assert torch.allclose( + masked, valid_only, atol=_PADDING_ATOL, rtol=0.0 + ), f"Padding-masked result diverges from valid-only by {max_err:.3g} > {_PADDING_ATOL}" # A query whose every key is padded out has an all -inf row; naive softmax would @@ -377,15 +377,30 @@ def test_inputs_not_mutated(): assert torch.equal(mask, mc) -def test_gradient_flows(): - """fp32 autograd (the backward golden source) yields finite grads for q, k, v.""" +def test_gradient_matches_reference(): + """fp32 autograd grads match autograd 
through the double-precision reference. + + isfinite only rules out NaN/Inf -- it can't tell a correct gradient from a + wrong-but-finite one, and attention's backward (softmax Jacobian + dQ/dK/dV + contractions) is the most error-prone in the stack. Backprop a *random* + cotangent (not .sum(), whose all-ones cotangent collapses the contraction) + and compare q/k/v grads against autograd through the independent + _ref_softmax_attn computed in float64 (TF32-immune high-precision golden). + """ op = NativeAttentionOp() q, k, v = _qkv(2, 8, 8, seed=8) q, k, v = q.requires_grad_(True), k.requires_grad_(True), v.requires_grad_(True) - op.forward_fp32(q, k, v, causal=True).sum().backward() - for t in (q, k, v): + + out = op.forward_fp32(q, k, v, causal=True) + gen = torch.Generator().manual_seed(8) + dy = torch.randn(out.shape, generator=gen, dtype=out.dtype) # seeded for reproducibility + out.backward(dy) + + qd, kd, vd = (t.detach().double().requires_grad_(True) for t in (q, k, v)) + _ref_softmax_attn(qd, kd, vd, causal=True).backward(dy.double()) + for t, td in ((q, qd), (k, kd), (v, vd)): assert t.grad is not None and t.grad.shape == t.shape - assert torch.isfinite(t.grad).all() + torch.testing.assert_close(t.grad, td.grad.float(), rtol=1e-4, atol=1e-4) def test_registry_dispatches_native_attention_op(): From 91c6b337e12c2b770e0701f06fc50d1e2f2bed48 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sun, 28 Jun 2026 16:54:09 +0800 Subject: [PATCH 6/7] docs(ws1): sync attention.md contract with impl per PR #195 review forward_fp32 returns fp32 (not the input dtype), and the shipped padding tolerance is atol=2e-6, not 1e-6. Aligns the operator contract doc with the code (CodeRabbit). 
--- docs/operators/attention.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/operators/attention.md b/docs/operators/attention.md index afe5e14..e3cb5f9 100644 --- a/docs/operators/attention.md +++ b/docs/operators/attention.md @@ -69,8 +69,9 @@ The op exposes the WS1 dual-path contract: expression is correct for both prefill (`Sq == Skv`) and decode (`Sq < Skv`, one query sees the whole cache). -Pure function — no randomness, no in-place mutation; device and original dtype follow the -inputs. Masks are built on the inputs' device. +Pure function — no randomness, no in-place mutation; device follows the inputs. `forward(...)` +preserves the input dtype, while `forward_fp32(...)` always returns fp32. Masks are built on +the inputs' device. ## Dispatch Behavior @@ -113,7 +114,7 @@ out = torch.matmul(probs, vf) # [B, Hq, Sq, D] across batch slicing **and chunked** (chunked-prefill) configurations — these keep the softmax reduction width fixed. `key_padding_mask` is the exception: padding changes the reduction width (e.g. Skv=10 vs 6), so the masked result only matches the valid-only result up to a small - tolerance (`atol=1e-6`), not bitwise, in IEEE 754. + tolerance (`atol=2e-6`), not bitwise, in IEEE 754. - **Axis B — tolerance**: as a `reduction` op, low-precision tolerance follows the `reduction` row of the WS1 numerical contract. Measured drift vs the fp32 golden path (rel-peak): From 22d45e6fe942623d8157c20f39dcce2dca3f9555 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sun, 28 Jun 2026 20:44:49 +0800 Subject: [PATCH 7/7] test(ws1): add backward batch-invariance slice test (Flink-ddd) Per PR #188 review: full-batch forward/backward under _single_thread(), then single-batch sliced backward with the corresponding dy slice, and assert q/k/v grad slices are bitwise identical via torch.equal. 
--- tests/test_attention.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/test_attention.py b/tests/test_attention.py index b9eb598..ffb5558 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -361,6 +361,36 @@ def test_batch_invariance_chunked(dtype: torch.dtype): assert torch.equal(torch.cat([c1, c2], dim=0), full) +@pytest.mark.parametrize("dtype", _DTYPES_AXIS_A) +def test_backward_batch_invariance_slice(dtype: torch.dtype): + """Axis-A (backward): q/k/v grads are bitwise-independent of batch size.""" + op = NativeAttentionOp() + q_full, k_full, v_full = _qkv(8, 16, 16, seed=10, dtype=dtype) + + q_full.requires_grad_(True) + k_full.requires_grad_(True) + v_full.requires_grad_(True) + + gen = torch.Generator().manual_seed(10) + dy = torch.randn(8, _N_HEADS, 16, _HEAD_DIM, generator=gen, dtype=dtype) + + with _single_thread(): + out_full = op.forward(q_full, k_full, v_full, causal=True) + out_full.backward(dy) + + q_slice = q_full[:1].detach().clone().requires_grad_(True) + k_slice = k_full[:1].detach().clone().requires_grad_(True) + v_slice = v_full[:1].detach().clone().requires_grad_(True) + + with _single_thread(): + out_slice = op.forward(q_slice, k_slice, v_slice, causal=True) + out_slice.backward(dy[:1]) + + assert torch.equal(q_slice.grad, q_full.grad[:1]) + assert torch.equal(k_slice.grad, k_full.grad[:1]) + assert torch.equal(v_slice.grad, v_full.grad[:1]) + + # --------------------------------------------------------------------------- # # Purity / gradient / registry # --------------------------------------------------------------------------- #