Problem Description
aiter.flash_attn_varlen_func produces numerically incorrect results when called with max_seqlen_q=1 (single-token decode). For FP16 the error is catastrophic (MAE up to 64). For BF16, only head_size=128 passes; head_size=64 and head_size=256 fail. All prefill configurations (q_len ≥ 8) pass for both dtypes.
Operating System
Ubuntu 22.04.5 LTS (Jammy Jellyfish)
CPU
AMD EPYC 9575F 64-Core Processor
GPU
8 x AMD Instinct MI325 (amdgcn-amd-amdhsa--gfx942:sramecc+:xnack)
ROCm Version
AMD-SMI 26.0.0+37d158ab amdgpu version: 6.12.12 ROCm version: 7.0.0
ROCm Component
No response
Steps to Reproduce
import torch
import aiter
torch.set_default_device("cuda")
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
num_q_heads = num_kv_heads = 8
head_size = 128
q_len = 1
kv_len = 512
scale = head_size**-0.5
for dtype in [torch.bfloat16, torch.float16]:
    q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype)
    k = torch.randn(kv_len, num_kv_heads, head_size, dtype=dtype)
    v = torch.randn_like(k)
    out = torch.empty_like(q)
    aiter.flash_attn_varlen_func(
        q=q, k=k, v=v,
        cu_seqlens_q=torch.tensor([0, q_len], dtype=torch.int32, device="cuda"),
        cu_seqlens_k=torch.tensor([0, kv_len], dtype=torch.int32, device="cuda"),
        max_seqlen_q=q_len, max_seqlen_k=kv_len, min_seqlen_q=1,
        dropout_p=0.0, softmax_scale=scale, causal=True,
        window_size=(-1, -1), alibi_slopes=None, return_lse=False, out=out,
    )
    # Float32 naive reference
    attn = torch.einsum("qhd,khd->hqk", q.float() * scale, k.float())
    attn = torch.softmax(attn, dim=-1).to(dtype)
    ref = torch.einsum("hqk,khd->qhd", attn, v)
    mae = (out.float() - ref.float()).abs().mean().item()
    print(f"{str(dtype):<20} MAE = {mae:.6f} {'PASS' if mae < 0.02 else 'FAIL'}")
Output on gfx942:
torch.bfloat16 MAE = 0.000117 PASS
torch.float16 MAE = 1.754807 FAIL
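For completeness, here is a hypothetical NumPy-only version of the same fp32 naive reference (not part of the repro; it needs no GPU, aiter, or torch), specialized to the q_len=1 decode case. The einsum strings mirror the torch reference above, so it can serve as an independent check that the reference math itself is sound:

```python
import numpy as np

def decode_attn_ref(q, k, v, scale):
    # q: (1, H, D), k/v: (S, H, D); per-head softmax(scale * q @ k^T) @ v in fp32
    attn = np.einsum("qhd,khd->hqk", q * scale, k)       # (H, 1, S) logits
    attn = attn - attn.max(axis=-1, keepdims=True)       # numerically stable softmax
    attn = np.exp(attn)
    attn /= attn.sum(axis=-1, keepdims=True)
    return np.einsum("hqk,khd->qhd", attn, v)            # (1, H, D) output

rng = np.random.default_rng(0)
H, D, S = 8, 128, 512                                    # heads, head_size, kv_len as in the repro
q = rng.standard_normal((1, H, D)).astype(np.float32)
k = rng.standard_normal((S, H, D)).astype(np.float32)
v = rng.standard_normal((S, H, D)).astype(np.float32)
out = decode_attn_ref(q, k, v, D ** -0.5)
print(out.shape)  # (1, 8, 128)
```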
Additional Information
The bug is specific to q_len=1 (decode mode). All prefill configurations (q_len ≥ 8) pass for both dtypes. BF16+head_size=128 is the only decode configuration that passes reliably. Decode inference for FP16 models using RocmAiterFABackend / flash_attn_varlen_func will produce incorrect attention outputs on every decode step.
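Until the kernel is fixed, a serving backend could gate the aiter decode path on the pass/fail matrix above. The helper below is purely hypothetical (it is not part of aiter or vLLM) and simply encodes the configurations observed in this report:

```python
def aiter_decode_path_ok(max_seqlen_q: int, dtype_name: str, head_size: int) -> bool:
    """Hypothetical gating helper encoding this report's observations."""
    if max_seqlen_q > 1:
        # Prefill: all tested configurations (q_len >= 8) passed for both dtypes.
        return True
    # Decode (q_len == 1): only BF16 with head_size=128 passed reliably;
    # FP16 (all head sizes) and BF16 with head_size 64/256 failed.
    return dtype_name == "bfloat16" and head_size == 128

print(aiter_decode_path_ok(1, "float16", 128))   # False -> fall back to another attention path
print(aiter_decode_path_ok(1, "bfloat16", 128))  # True
```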
You can reproduce this issue using the nightly vLLM docker: docker pull rocm/vllm-dev:nightly
cc @kenroche