36 changes: 14 additions & 22 deletions tests/pytorch/attention/test_attention.py
@@ -2550,28 +2550,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
 }
 param_types_fp8 = [torch.float16, torch.bfloat16]
-cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
-models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
-models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
 
 
 @pytest.mark.skipif(
-    (
-        get_cudnn_version() < (8, 9, 3)
-        if cudnn_frontend_version == 0
-        else get_cudnn_version() < (9, 2, 1)
-    ),
-    reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""",
+    get_cudnn_version() < (9, 2, 1),
+    reason="cuDNN 9.2.1+ is required for FP8 fused attention.",
 )
 @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn)
 @pytest.mark.parametrize("dtype", param_types_fp8)
-@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0)
+@pytest.mark.parametrize("model", model_configs_fp8)
 def test_custom_mha_fp8_vs_f16(dtype, model):
     """Test FP8 dot product attention implementations based on cuDNN frontend
     v0.9 and v1.0+. Each test compares results from a custom implementation of
     an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA
     implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention.
-    Both paths take F16 input and output. QKV layout is t3hd or bs3hd"""
+    Both paths take F16 input and output. QKV layout is bs3hd"""
 
     config = model_configs_fp8[model]
 
@@ -2580,7 +2573,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
     available_backends, _, fused_attn_backends = get_available_attention_backends(
         config,
         qkv_dtype=torch.float8_e4m3fn,
-        qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd",
+        qkv_layout="bs3hd",
         is_training=is_training,
         deterministic=_deterministic,
     )
@@ -2787,18 +2780,17 @@ def forward(
             quantization_params=qkv_quantizer,
             use_split_accumulator=_2X_ACC_FPROP,
         )
-        qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd"
-        o_format = "bshd" if cudnn_frontend_version == 1 else "thd"
+        qkv_layout = "bs3hd"
+        o_format = "bshd"
         qkv = qkv.view(-1, 3, h, d)
         qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous()
         torch.save(qkv_fp16, "qkv.pt")
-        if cudnn_frontend_version == 1:
-            qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd
+        qkv = qkv.view(b, max_s, 3, h, d)  # bs3hd
 
         # FMHA
-        q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :]
-        k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :]
-        v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :]
+        q_data = qkv._data[:, :, 0, :, :]
+        k_data = qkv._data[:, :, 1, :, :]
+        v_data = qkv._data[:, :, 2, :, :]
         q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape)
         k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape)
         v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape)
@@ -2820,7 +2812,7 @@
             qkv_layout=qkv_layout,
             o_format=o_format,
             attn_bias_type="no_bias",
-            attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
+            attn_mask_type=mask_type,
             rng_gen=None,
             o_quantizer=o_quantizer,
             s_quantizer=s_quantizer,
@@ -2887,9 +2879,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             do_format=ctx.o_format,
             dqkv_layout=ctx.qkv_layout,
             attn_bias_type="no_bias",
-            attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
+            attn_mask_type=ctx.mask_type,
         )
-        dim = 2 if cudnn_frontend_version == 1 else 1
+        dim = 2
         dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype)
         dqkv_shape = list(dq._data.shape)
         dqkv_shape.insert(dim, 3)
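Note on the surviving test path: with the v0 t3hd branch gone, the test always packs Q, K, and V into a single bs3hd tensor and slices along dim 2. A minimal sketch of that slicing on a plain torch tensor (shapes are arbitrary illustration values; the test itself does this on Float8Tensor via `_data` and `make_like`):

```python
import torch

# Illustrative bs3hd packed tensor: batch=2, seq=512, 3 matrices (Q/K/V),
# heads=16, head_dim=128 -- values chosen only for this sketch.
b, s, h, d = 2, 512, 16, 128
qkv = torch.randn(b, s, 3, h, d, dtype=torch.float16)

# Slice along dim 2, mirroring qkv._data[:, :, 0/1/2, :, :] in the diff above.
q, k, v = (qkv[:, :, i, :, :] for i in range(3))  # each is bshd: (b, s, h, d)
assert q.shape == (b, s, h, d)
```

This is also why `backward` now unconditionally sets `dim = 2`: the gradient tensors are re-stacked along the same QKV axis that `forward` sliced.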
75 changes: 30 additions & 45 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -254,35 +254,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 
   if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) &&
       sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS &&
-      // 8.9: t3hd, max_s=512, d=64, padding
-      ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 &&
-        qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv &&
-        max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 &&
-        attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
-       // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
-       (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
-        max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
-       // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
-       (cudnn_runtime_version >= 90700 &&
-        // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-        // sm90: fwd d<=256, bwd d=128 only
-        // sm100: fwd d<=128, bwd d<=128
-        ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
-         (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
-         (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
-        head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
-       // 9.21: d_qk=192, d_v=128
-       (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 &&
-        head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
-        (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
-         attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
+      (
+          // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal}
+          (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 &&
+           max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) ||
+          // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal}
+          (cudnn_runtime_version >= 90700 &&
+           // TODO (cyang): add is_training to nvte_get_fused_attn_backend
+           // sm90: fwd d<=256, bwd d=128 only
+           // sm100: fwd d<=128, bwd d<=128
+           ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) ||
+            (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) ||
+            (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) &&
+           head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) ||
+          // 9.21: d_qk=192, d_v=128
+          (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 &&
+           head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 &&
+           (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+            attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) &&
       // pre-9.21: {bshd, sbhd}, {vanilla}
       // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable}
       ((cudnn_runtime_version < 92100 &&
@@ -294,14 +290,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       !requires_64bit_ragged_offset &&
       // 9.10.0: known bugs with SDPA FP8
      (cudnn_runtime_version != 91000) && !return_max_logit) {
-    if (cudnn_runtime_version >= 8900) {
-      backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
-    } else {
-      backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
-      std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
-                   " Please upgrade your cuDNN version if possible."
-                << std::endl;
-    }
+    backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
   } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
     bool flag_arb = false;
     if (
@@ -727,10 +716,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
       size_t i = 0;
       const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      const Tensor *input_ZInv = nullptr;
-      if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) {
-        input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      }
       const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       const Tensor *input_SoftmaxOffset = nullptr;
       if (softmax_type != NVTE_VANILLA_SOFTMAX) {
@@ -744,10 +729,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
           qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format,
           do_scale_inv_format, bias_type, attn_mask_type, softmax_type,
           window_size_left, window_size_right, bottom_right_diagonal, deterministic,
-          input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M,
-          input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, output_dQ,
-          output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q,
-          input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle);
+          input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S,
+          input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV,
+          output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv,
+          input_rng_state, wkspace, stream, handle);
     } else {
       NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n");
     }
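The net effect of the C++ change: the legacy cuDNN 8.9 t3hd branch is gone, so the FP8 backend is now gated purely on the three remaining version tiers (9.2.1, 9.7, 9.21), and the ZInv auxiliary tensor that only the t3hd path produced is dropped from the backward call. A rough Python paraphrase of the simplified predicate, for illustration only — the function name and string mask labels here are invented for the sketch; the real check is the C++ expression in `nvte_get_fused_attn_backend()` above:

```python
def fp8_backend_available(cudnn: int, sm: int, d_qk: int, d_v: int,
                          mask: str, training: bool,
                          max_sq: int, max_skv: int) -> bool:
    # 9.2.1+: sm90 only, seqlens divisible by 128, d=128, {no_mask, causal}
    v921 = (cudnn >= 90201 and sm < 100 and max_sq % 128 == 0
            and max_skv % 128 == 0 and d_qk == d_v == 128
            and mask in ("no_mask", "causal"))
    # 9.7+: fwd d<=256 on sm90 (inference only), d<=128 otherwise; adds padding masks
    v97 = (cudnn >= 90700 and d_qk % 16 == 0 and d_v % 16 == 0
           and ((sm < 100 and not training and d_qk <= 256 and d_v <= 256)
                or (sm < 100 and training and d_qk == d_v == 128)
                or (sm >= 100 and d_qk <= 128 and d_v <= 128))
           and mask in ("no_mask", "causal", "padding", "padding_causal"))
    # 9.21+: sm100, d_qk<=192, d_v<=128, adds bottom-right causal
    v9210 = (cudnn >= 92100 and sm >= 100 and d_qk <= 192 and d_v <= 128
             and d_qk % 16 == 0 and d_v % 16 == 0
             and mask in ("no_mask", "causal", "causal_bottom_right"))
    return v921 or v97 or v9210
```

For example, `fp8_backend_available(90700, 90, 256, 256, "no_mask", False, 1024, 1024)` is True (sm90 inference with d=256 under the 9.7 tier), while the same call with `training=True` is False, matching the "bwd d=128 only" comment in the diff.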