[Issue]: flash_attn_varlen_func incorrect results for FP16 + q_len=1 (decode) on gfx942 #2229

@AndreasKaratzas

Description

Problem Description

aiter.flash_attn_varlen_func produces numerically incorrect results when called with max_seqlen_q=1 (single-token decode). For FP16 the error is catastrophic (MAE up to 64) across all tested head sizes. For BF16, only head_size=128 passes; head_size=64 and head_size=256 fail. All prefill configurations (q_len ≥ 8) pass for both dtypes.

Operating System

Ubuntu 22.04.5 LTS (Jammy Jellyfish)

CPU

AMD EPYC 9575F 64-Core Processor

GPU

8 x AMD Instinct MI325 (amdgcn-amd-amdhsa--gfx942:sramecc+:xnack)

ROCm Version

AMD-SMI version: 26.0.0+37d158ab
amdgpu version: 6.12.12
ROCm version: 7.0.0

ROCm Component

No response

Steps to Reproduce

import torch
import aiter

torch.set_default_device("cuda")
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

num_q_heads = num_kv_heads = 8
head_size = 128
q_len = 1
kv_len = 512
scale = head_size**-0.5

for dtype in [torch.bfloat16, torch.float16]:
    q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype)
    k = torch.randn(kv_len, num_kv_heads, head_size, dtype=dtype)
    v = torch.randn_like(k)
    out = torch.empty_like(q)

    aiter.flash_attn_varlen_func(
        q=q, k=k, v=v,
        cu_seqlens_q=torch.tensor([0, q_len], dtype=torch.int32, device="cuda"),
        cu_seqlens_k=torch.tensor([0, kv_len], dtype=torch.int32, device="cuda"),
        max_seqlen_q=q_len, max_seqlen_k=kv_len, min_seqlen_q=1,
        dropout_p=0.0, softmax_scale=scale, causal=True,
        window_size=(-1, -1), alibi_slopes=None, return_lse=False, out=out,
    )

    # Naive reference: softmax accumulated in float32, PV matmul in the test dtype
    attn = torch.einsum("qhd,khd->hqk", q.float() * scale, k.float())
    attn = torch.softmax(attn, dim=-1).to(dtype)
    ref = torch.einsum("hqk,khd->qhd", attn, v)

    mae = (out.float() - ref.float()).abs().mean().item()
    print(f"{str(dtype):<20}  MAE = {mae:.6f}  {'PASS' if mae < 0.02 else 'FAIL'}")

Output on gfx942:

torch.bfloat16        MAE = 0.000117  PASS
torch.float16         MAE = 1.754807  FAIL
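Note that for q_len=1 with causal=True, the single query sits at the last position of the sequence, so it attends to every key and the causal mask is effectively a no-op. The reference in the repro therefore reduces to a plain softmax over the full KV cache. A minimal CPU-runnable sketch of that decode reference, using the same shapes as the repro (function name is illustrative, not part of aiter):

```python
import torch

def decode_reference(q, k, v, scale):
    """Float32 reference for single-token decode attention.

    q: (1, H, D); k, v: (S, H, D). With q_len=1 the causal mask
    is a no-op, so the query attends over the whole KV cache.
    """
    attn = torch.einsum("qhd,khd->hqk", q.float() * scale, k.float())
    attn = torch.softmax(attn, dim=-1)  # (H, 1, S), rows sum to 1
    return torch.einsum("hqk,khd->qhd", attn, v.float())

q = torch.randn(1, 8, 128)
k = torch.randn(512, 8, 128)
v = torch.randn_like(k)
out = decode_reference(q, k, v, 128 ** -0.5)
print(out.shape)  # torch.Size([1, 8, 128])
```

Comparing the aiter output against this fully-fp32 reference (rather than one that downcasts the probabilities to the test dtype) gives the same pass/fail pattern, since the downcast contributes error far below the 0.02 MAE threshold.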

Additional Information

The bug is specific to q_len=1 (decode mode). All prefill configurations (q_len ≥ 8) pass for both dtypes. BF16+head_size=128 is the only decode configuration that passes reliably. Decode inference for FP16 models using RocmAiterFABackend / flash_attn_varlen_func will produce incorrect attention outputs on every decode step.

You can reproduce this issue using the nightly vLLM docker: docker pull rocm/vllm-dev:nightly

cc @kenroche
