From 2b9fbc5c3dc59e86187350ba875eee819e8f2bf0 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 5 May 2026 18:50:17 -0700 Subject: [PATCH 1/3] refactor nvte_get_fused_attn_backend with FE calls Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 516 ++++++------------ .../fused_attn_f16_arbitrary_seqlen.cu | 136 +++++ .../fused_attn_f16_arbitrary_seqlen.h | 22 + .../common/fused_attn/fused_attn_fp8.cu | 101 ++++ .../common/fused_attn/fused_attn_fp8.h | 25 + .../include/transformer_engine/fused_attn.h | 50 +- .../jax/csrc/extensions/attention.cpp | 32 +- .../pytorch/csrc/extensions/attention.cpp | 17 +- 8 files changed, 539 insertions(+), 360 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 141767b803..615f7c2a03 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -226,357 +226,189 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { } } +namespace { + +// Per-thread storage for the message string handed back through +// NVTEFusedAttnBackendStatus::message. Re-used (cleared + re-populated) on every call to +// nvte_get_fused_attn_backend on this thread, which is exactly the lifetime documented in the +// public header. +thread_local std::string g_fused_attn_backend_status_buffer; + +// Apply (code, msg) to *out_status (if non-null), routing the message through the +// thread-local buffer so the returned `const char*` outlives this function call. 
+void set_status(NVTEFusedAttnBackendStatus *out_status, cudnn_frontend::error_code_t code, + const std::string &message) { + if (out_status == nullptr) return; + g_fused_attn_backend_status_buffer = message; + out_status->code = static_cast(code); + out_status->message = g_fused_attn_backend_status_buffer.c_str(); +} + +void set_status(NVTEFusedAttnBackendStatus *out_status, const cudnn_frontend::error_t &err) { + set_status(out_status, err.code, err.err_msg); +} + +void set_ok(NVTEFusedAttnBackendStatus *out_status) { + set_status(out_status, cudnn_frontend::error_code_t::OK, ""); +} + +} // namespace + // select a backend for fused attention +// +// Routing flow: +// 1. Apply TE post-filters that encode policies cuDNN-FE doesn't model directly: +// a. requires_64bit_ragged_offset -> cudnn >= 9.5 +// b. qkv_format == THD requires a padding-style mask +// c. cuDNN <= 9.15 + is_training + bshd/sbhd + max_seqlen_kv % 128 != 0 + +// cuda_graph + non-padding mask is rejected (known capture quirk) +// 2. Dispatch by dtype to the appropriate probe(s): +// - FP8 (E4M3/E5M2): is_supported_fp8_fwd (+ is_supported_fp8_bwd if training) +// - FP16/BF16: is_supported_f16_fwd (+ is_supported_f16_bwd if training) +// The probes call the same _impl that the executor uses, with workspace=nullptr. +// They run validate -> build_operation_graph -> create_execution_plans -> +// check_support -> build_plans, and populate a thread-local cache that the +// executor cache-hits on. +// 3. Return the selected backend, or NVTE_No_Backend if any probe rejects. +// +// When `out_status` is non-null, it is filled with a code + message describing the +// rejection (or {OK, ""} on success). TE post-filter rejections synthesize an +// INVALID_VALUE entry; probe rejections forward the cuDNN-FE / NVTE_CHECK error verbatim. 
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, + NVTEFusedAttnBackendStatus *out_status) { using namespace transformer_engine; - NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); + // Initialize to OK so callers get a clean status on the success path without us having to + // remember to set it at every return. + set_ok(out_status); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - auto cudnn_runtime_version = cudnnGetVersion(); - // For ragged offsets we only support 32-bit prior to cuDNN 9.5 - // Only used when THD format is requested. 
+ const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const auto cudnn_runtime_version = cudnnGetVersion(); + + // ---------- TE post-filters (apply before delegating to cuDNN-FE) ---------- + + // (1) Ragged-offset width: cuDNN < 9.5 only supports 32-bit offsets. const bool requires_64bit_ragged_offset = (qkv_format == NVTE_THD && fused_attn::get_ragged_offset_dtype( layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); - const bool supported_ragged_offset_size = - (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - - if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && - sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // 8.9: t3hd, max_s=512, d=64, padding - ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - 
head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: d_qk=192, d_v=128 - (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && - head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && - // pre-9.21: {bshd, sbhd}, {vanilla} - // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} - ((cudnn_runtime_version < 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) || - (cudnn_runtime_version >= 92100 && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD))) && - !requires_64bit_ragged_offset && - // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000) && !return_max_logit) { - if (cudnn_runtime_version >= 8900) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { - bool flag_m512 = false; - bool flag_arb = false; - if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && - (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && - ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - max_seqlen_q == max_seqlen_kv) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && - ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && - ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) { - flag_m512 = true; + if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "Configuration requires 64-bit ragged offsets, which require cuDNN >= 9.5."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // (2) THD requires a padding-style mask. 
+ if (qkv_format == NVTE_QKV_Format::NVTE_THD && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "THD-format attention requires a padding-style mask " + "(PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT)."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // (3) cuDNN-Graph capture quirk on cuDNN <= 9.15: training + bshd/sbhd with + // max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask hangs/miscompiles. + if (cudnn_runtime_version <= 91500 && is_training && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " + "max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask is unsupported."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; + } + + // ---------- Dispatch by dtype ---------- + + // Probes use a single-batch graph; capability checks in cuDNN-FE are batch-agnostic. + constexpr size_t probe_batch = 1; + // bottom_right_diagonal is a runtime API knob the router doesn't see; the BRCM-via-mask + // case is captured by attn_mask_type, so we probe with the default top-left alignment. 
+ constexpr bool probe_bottom_right_diagonal = false; + + const bool is_fp8 = + (q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2); + const bool is_f16_or_bf16 = + (q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16); + + if (is_fp8) { + // TE-only FP8 post-filters: no 64-bit ragged offsets, no max-logit output. + if (requires_64bit_ragged_offset) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention does not support 64-bit ragged offsets."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if ( - // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // architecture - ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || - (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || - (cudnn_runtime_version >= 90700 && sm_arch_ >= 100)) && - // sequence length - ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || - (cudnn_runtime_version >= 90000)) && - // number of heads - ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 8907)) && - // head dimension - // multiples of 8 - (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 && - // <= 128 - ((head_dim_qk <= 128 && head_dim_v <= 128) || - // 9.1: <= 256 + Hopper + fprop - // 9.5: <= 256 + Hopper + bprop - (head_dim_qk <= 256 && head_dim_v <= 256 && - ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || - (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || - // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 - (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && - layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10.2: any head_dim + any arch + fprop + paged - // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10.2: any head_dim + any 
arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91002 && - (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || - (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || - // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged - (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && - // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed - (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && - head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && - head_dim_qk != head_dim_v))) && - // bias type - ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (cudnn_runtime_version >= 8906 && - (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - (bias_type == NVTE_Bias_Type::NVTE_ALIBI && - attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - sm_arch_ >= 90) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - (cudnn_runtime_version >= 90000 && - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && - // mask type - // pre-8.9.6: causal - ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} - (cudnn_runtime_version >= 8906 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == 
NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.1: adds thd + {padding, padding_causal} - (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90300 && - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} - (cudnn_runtime_version >= 90500 && - layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) - (cudnn_runtime_version >= 90600 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} - // for any q_format/kv_format, and paged/non-paged - (cudnn_runtime_version >= 90700 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - 
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && - max_seqlen_q <= max_seqlen_kv)))) && - // bias + mask combination - (!(cudnn_runtime_version >= 8906 && - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - // qkv format - (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_BHSD || - (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && - ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - cudnn_runtime_version >= 90600)) || - ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || - q_format == NVTE_QKV_Format::NVTE_BHSD || - (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || - kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || - kv_format == NVTE_QKV_Format::NVTE_BHSD || - (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && - cudnn_runtime_version >= 90700)) && - // sliding window - // pre-9.2: full attn, causal - ((cudnn_runtime_version < 90200 && window_size_left == -1 && - (window_size_right == -1 || window_size_right == 0)) || - // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} - (cudnn_runtime_version >= 90200 && - ((window_size_left == -1 && window_size_right == -1 && - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) || - ((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 && - (attn_mask_type == 
NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q == max_seqlen_kv)) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || - // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} - (cudnn_runtime_version >= 90600 && - ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || - ((window_size_left >= 0 || window_size_left == -1) && - (window_size_right >= 0 || window_size_right == -1) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - // TODO(cyang): fix bug for BRCM + cross-attention on sm100 - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700)))) || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && - (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && - cudnn_runtime_version <= 90700) || - cudnn_runtime_version > 90700))))) && - max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - dropout == 0.0)))) && - // check 64-bit ragged offset support - (supported_ragged_offset_size) && - // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && - // softmax type - // pre-9.13.1: vanilla - // 9.13.1+: vanilla, off-by-one, learnable - (cudnn_runtime_version >= 91301 || - (cudnn_runtime_version < 91301 && - softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) && - // determinism on Blackwell - // pre-9.18.1: fwd: deterministic; bwd: non-deterministic - // 9.18.1+: fwd: deterministic; bwd: 
non-deterministic/deterministic - (sm_arch_ < 100 || - (sm_arch_ >= 100 && (!is_training || - (is_training && !deterministic && - (dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) || - (is_training && deterministic && cudnn_runtime_version >= 91801 && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { - flag_arb = true; + if (return_max_logit) { + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention does not support return_max_logit."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; + const DType q_t = static_cast(q_dtype); + const DType o_t = static_cast(o_dtype); + auto fwd_status = is_supported_fp8_fwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, o_t, scaling_mode, + handle); + if (fwd_status.is_bad()) { + set_status(out_status, fwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - if (flag_arb == true) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } else if ((flag_arb == false) && (flag_m512 == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; - } - int env_backend = static_cast(backend); - env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && - flag_m512) || - ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && - flag_arb)) { - backend = static_cast(env_backend); + if (is_training) { + auto bwd_status = is_supported_fp8_bwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, 
max_seqlen_kv, head_dim_qk, + head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, + o_t, scaling_mode, handle); + if (bwd_status.is_bad()) { + set_status(out_status, bwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } - if (cudnn_runtime_version < 8901 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - if (cudnn_runtime_version < 8900 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } - if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-causal) and " - "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " - " Please upgrade your cuDNN version if possible." 
- << std::endl; - } - if ((cudnn_runtime_version <= 91500) && is_training && - (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - (max_seqlen_kv % 128 != 0) && cuda_graph && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && - (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of attention mask (non-padding)," - " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" - " backward fused attention with graph capture requires cuDNN 9.15.1+. " - "Please upgrade your cuDNN version if possible." - << std::endl; + return NVTE_Fused_Attn_Backend::NVTE_FP8; + } + + if (is_f16_or_bf16) { + const DType q_t = static_cast<DType>(q_dtype); + auto fwd_status = is_supported_f16_fwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, + handle); + if (fwd_status.is_bad()) { + set_status(out_status, fwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen && sm_arch_ == 120) { - if (cudnn_runtime_version < 91801) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of sm_arch_ == 120 and cudnn_runtime_version < " - "91801 is not supported. " - << " Please upgrade your cuDNN version if possible." << std::endl; - } else if (deterministic && is_training) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Deterministic fused attention on SM120 is not supported." 
- << std::endl; - } else { - // Known missing support for T3HD/TH3D layouts on SM120 - const bool is_t3hd_or_th3d = - (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD || qkv_layout == NVTE_QKV_Layout::NVTE_TH3D); - if (is_t3hd_or_th3d) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: Given combination of T3HD/TH3D layouts on SM120 is not supported. " - << " Please consider using other THD layouts if possible." << std::endl; - } + if (is_training) { + auto bwd_status = is_supported_f16_bwd( + probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, + head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, + handle); + if (bwd_status.is_bad()) { + set_status(out_status, bwd_status); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - return backend; + + set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, + "Unsupported Q dtype for fused attention " + "(only FP16/BF16/FP8_E4M3/FP8_E5M2 are routable)."); + return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } // NVTE fused attention FWD with separate Q, K and V @@ -661,11 +493,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); + const NVTEDType O_type = static_cast<NVTEDType>(output_O->data.dtype); + const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, 
window_size_right, - return_max_logit, cuda_graph, false); + is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, + window_size_right, return_max_logit, cuda_graph, /*deterministic=*/false, handle, + /*out_status=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, @@ -747,11 +582,14 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype); const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype); + const NVTEDType O_type = static_cast<NVTEDType>(input_O->data.dtype); + const NVTEScalingMode scaling_mode = input_Q->scaling_mode; NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, - cuda_graph, deterministic); + /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, + attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, + window_size_left, window_size_right, /*return_max_logit=*/false, cuda_graph, deterministic, + handle, /*out_status=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 6df7ad35c8..57ca14a3e1 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ 
b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1333,4 +1333,140 @@ void fused_attn_arbitrary_seqlen_bwd( NVTE_ERROR("Unexpected workspace_size."); } } + +namespace { +// Probe-time defaults for runtime-only quantities the router doesn't see (paged-KV dims, +// ragged max-tokens, bias dims). These produce a graph whose support surface matches the +// real executor's: for non-paged / non-ragged paths these are unused inside the impl; +// for ragged-THD we rebind to worst-case bounds; for paged we use 1 page of full s_kv per +// batch (= same dims as non-paged), so cuDNN-FE applies the paged-attention support rules. +struct ProbeDims { + int64_t max_b; + int64_t max_t_q; + int64_t max_t_kv; + int64_t num_pages_k; + int64_t num_pages_v; + int64_t page_size_k; + int64_t page_size_v; + int64_t max_pages_per_seq_k; + int64_t max_pages_per_seq_v; + int64_t bias_b; + int64_t bias_h; + int64_t bias_sq; + int64_t bias_skv; +}; + +ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_seqlen_q, + int64_t max_seqlen_kv, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type) { + const NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + const bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + const bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); + const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + const bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); + const bool has_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); + + ProbeDims d{}; + d.max_b = (is_ragged_q || is_ragged_kv) ? batch : 0; + d.max_t_q = is_ragged_q ? batch * max_seqlen_q : 0; + d.max_t_kv = is_ragged_kv ? batch * max_seqlen_kv : 0; + d.num_pages_k = is_paged_kv ? batch : 0; + d.num_pages_v = is_paged_kv ? batch : 0; + d.page_size_k = is_paged_kv ? 
max_seqlen_kv : 0; + d.page_size_v = is_paged_kv ? max_seqlen_kv : 0; + d.max_pages_per_seq_k = is_paged_kv ? 1 : 0; + d.max_pages_per_seq_v = is_paged_kv ? 1 : 0; + d.bias_b = has_bias ? batch : 0; + d.bias_h = has_bias ? num_attn_heads : 0; + d.bias_sq = has_bias ? max_seqlen_q : 0; + d.bias_skv = has_bias ? max_seqlen_kv : 0; + return d; +} +} // namespace + +cudnn_frontend::error_t is_supported_f16_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, + bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle) { + const ProbeDims d = compute_probe_dims(static_cast(batch), + static_cast(num_attn_heads), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), qkv_layout, + bias_type); + const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); + + size_t workspace_size = 0; + try { + fused_attn::fused_attn_arbitrary_seqlen_fwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.num_pages_k, + d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, + d.max_pages_per_seq_v, d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, + return_max_logit, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, + /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr, + /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, 
/*devPtrDropoutOffset=*/nullptr, + /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, + /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_f16_fwd: unknown failure"}; + } +} + +cudnn_frontend::error_t is_supported_f16_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle) { + const ProbeDims d = compute_probe_dims(static_cast(batch), + static_cast(num_attn_heads), + static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), qkv_layout, + bias_type); + const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); + const NVTE_QKV_Format do_format = o_format; + const NVTE_QKV_Layout dqkv_layout = qkv_layout; + + size_t workspace_size = 0; + try { + fused_attn::fused_attn_arbitrary_seqlen_bwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.bias_b, d.bias_h, + d.bias_sq, d.bias_skv, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, do_format, + dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + bottom_right_diagonal, deterministic, 
/*devPtrQ=*/nullptr, /*devPtrKTranspose=*/nullptr, + /*devPtrVTranspose=*/nullptr, /*devPtrO=*/nullptr, /*devPtrSoftmaxStats=*/nullptr, + /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrdQ=*/nullptr, + /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, /*devPtrdO=*/nullptr, + /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr, + /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, + /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, + get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_f16_bwd: unknown failure"}; + } +} + } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 8f79b5bb4a..38cf48c1f0 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -12,6 +12,7 @@ #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ #include +#include #include "common/common.h" #include "transformer_engine/fused_attn.h" @@ -47,6 +48,27 @@ void fused_attn_arbitrary_seqlen_bwd( const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); +// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> +// check_support -> build_plans) for an F16/BF16 forward graph with the given configuration. +// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. 
On OK, +// the built graph is inserted into the same thread-local cache used by +// fused_attn_arbitrary_seqlen_fwd_impl, so the executor cache-hits on matching descriptors. +// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. +cudnn_frontend::error_t is_supported_f16_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, + bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle); + +// Probe: same as above for the F16/BF16 backward graph. +cudnn_frontend::error_t is_supported_f16_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle); + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d97f388459..c9f7a9ee76 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2991,4 +2991,105 @@ void fused_attn_fp8_bwd( return; } } + +cudnn_frontend::error_t is_supported_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, 
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle) { + // FP8 fwd impl rejects any qkv_format other than BSHD/SBHD/BHSD with NVTE_ERROR; mirror that + // here so the probe returns a typed rejection instead of catching the throw. + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && + qkv_format != NVTE_QKV_Format::NVTE_BHSD) { + return {cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + } + size_t workspace_size = 0; + try { + fused_attn::fused_attn_fp8_fwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout, + qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, bottom_right_diagonal, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, + /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, + /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, + /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, + /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, + /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), + get_cudnn_fe_dtype(o_dtype), scaling_mode, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*workspace=*/nullptr, &workspace_size, + /*stream=*/static_cast(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const 
std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_fp8_fwd: unknown failure"}; + } +} + +cudnn_frontend::error_t is_supported_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle) { + const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && + qkv_format != NVTE_QKV_Format::NVTE_BHSD) { + return {cudnn_frontend::error_code_t::INVALID_VALUE, + "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + } + // For FP8 bwd, dO data type matches O data type and dQKV data type matches Q data type + // (this mirrors the assumption used by callers of fused_attn_fp8_bwd in TE). 
+ const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(q_dtype); + const cudnn_frontend::DataType_t o_t = get_cudnn_fe_dtype(o_dtype); + const cudnn_frontend::DataType_t do_t = o_t; + const cudnn_frontend::DataType_t dqkv_t = qkv_t; + size_t workspace_size = 0; + try { + fused_attn::fused_attn_fp8_bwd_impl( + static_cast(batch), static_cast(num_attn_heads), + static_cast(num_gqa_groups), static_cast(max_seqlen_q), + static_cast(max_seqlen_kv), static_cast(head_dim_qk), + static_cast(head_dim_v), /*scaling_factor=*/1.0f, p_dropout, qkv_layout, + /*o_format=*/qkv_format, /*do_format=*/qkv_format, /*dqkv_layout=*/qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + deterministic, + /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrM=*/nullptr, + /*devPtrO=*/nullptr, /*devPtrdO=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, + /*devPtrdQ=*/nullptr, /*devPtrdK=*/nullptr, /*devPtrdV=*/nullptr, + /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDescaleQ=*/nullptr, + /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleO=*/nullptr, + /*devPtrDescaledO=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrDescaledP=*/nullptr, + /*devPtrScaleS=*/nullptr, /*devPtrScaledP=*/nullptr, /*devPtrScaledQ=*/nullptr, + /*devPtrScaledK=*/nullptr, /*devPtrScaledV=*/nullptr, /*devPtrAmaxdP=*/nullptr, + /*devPtrAmaxdQ=*/nullptr, /*devPtrAmaxdK=*/nullptr, /*devPtrAmaxdV=*/nullptr, + /*devPtrQ_t=*/nullptr, /*devPtrK_t=*/nullptr, /*devPtrdO_f16=*/nullptr, + /*devPtrdO_t=*/nullptr, /*devPtrDescaleQ_t=*/nullptr, /*devPtrDescaleK_t=*/nullptr, + /*devPtrDescaledO_t=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, + /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, + /*devPtrDropoutOffset=*/nullptr, qkv_t, o_t, do_t, dqkv_t, scaling_mode, + /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, + /*workspace=*/nullptr, 
&workspace_size, + /*stream=*/static_cast<cudaStream_t>(0), handle); + return {cudnn_frontend::error_code_t::OK, ""}; + } catch (const std::exception &e) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + } catch (...) { + return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, + "is_supported_fp8_bwd: unknown failure"}; + } +} + } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index aaf5039eeb..5c7f11d80e 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -8,6 +8,8 @@ * \brief Functions for fused attention for FP8 with seqlen <= 512 */ +#include <cudnn_frontend.h> + #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -39,4 +41,26 @@ void fused_attn_fp8_bwd( const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + +// Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> +// check_support -> build_plans) for an FP8 forward graph with the given configuration. +// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK, +// the built graph is inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. +// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. 
+cudnn_frontend::error_t is_supported_fp8_fwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, + cudnnHandle_t handle); + +// Probe: same as above for the FP8 backward graph. +cudnn_frontend::error_t is_supported_fp8_bwd( + size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, + NVTEScalingMode scaling_mode, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 912dc32d35..787e97d628 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -11,6 +11,8 @@ #ifndef TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ #define TRANSFORMER_ENGINE_FUSED_ATTN_FP8_H_ +#include <cudnn.h> + #include "stdint.h" #include "transformer_engine.h" @@ -196,11 +198,40 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); +/*! \struct NVTEFusedAttnBackendStatus + * \brief Diagnostic info from \c nvte_get_fused_attn_backend. 
+ * + * Filled by \c nvte_get_fused_attn_backend when the caller passes a non-NULL pointer. + * When the routing decision is supported, \c code is 0 and \c message is the empty + * string. When the routing rejects the configuration, \c code is the underlying + * cuDNN-FE \c cudnn_frontend::error_code_t cast to \c int (TE-synthesized post-filter + * rejections use \c INVALID_VALUE), and \c message is a null-terminated human-readable + * reason that points into per-thread storage owned by TE. The pointer is valid only + * until the next call to \c nvte_get_fused_attn_backend on the same thread. + */ +typedef struct NVTEFusedAttnBackendStatus { + int code; + const char *message; +} NVTEFusedAttnBackendStatus; + /*! \brief Get fused attention backend based on input parameters. + * + * Authoritative routing: when a non-NVTE_No_Backend value is returned, the configuration + * is guaranteed to compile through cuDNN-FE (validate -> build_operation_graph -> + * create_execution_plans -> check_support -> build_plans). The router applies a small + * set of TE-specific post-filters in addition to delegating to cuDNN-FE for capability + * checks. On success the built plan is cached, so the executor avoids rebuilding. * * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] o_dtype The data type of output Tensor O. Used by the FP8 + * branch to disambiguate FP8 vs HALF/BF16 output; + * ignored by the F16/BF16 branch (pass q_dtype). + * \param[in] scaling_mode Scaling mode of the input tensors. Used by the FP8 + * branch to select among delayed/current/MXFP8 recipes; + * ignored by the F16/BF16 branch + * (pass NVTE_DELAYED_TENSOR_SCALING). * \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. 
@@ -217,13 +248,22 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] return_max_logit Whether to produce Max along with Stats. * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] deterministic Whether determinism is required or not. + * \param[in] handle cuDNN handle used for the support chain. Required. + * \param[out] out_status Optional. When non-NULL, populated with a code + + * message describing why the configuration was + * rejected (NVTE_No_Backend) or with code=0 and + * message="" on success. The message buffer lives in + * thread-local storage and is overwritten on every + * call on the same thread. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic); + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, + NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, + size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, + size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, + bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, + NVTEFusedAttnBackendStatus *out_status); /*! \brief Compute dot product attention with separate Q, K and V. 
* diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 76f2d92891..c6a8897089 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "../extensions.h" +#include "common/cudnn_utils.h" #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -17,11 +18,14 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right, bool deterministic) { + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, static_cast(q_dtype), static_cast(kv_dtype), + static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, + mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, + kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, + /*out_status=*/nullptr); return backend; } @@ -272,11 +276,13 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + auto _handle_fwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, 
mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, static_cast(dtype), static_cast(dtype), + static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, + softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, + /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_status=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -548,11 +554,13 @@ static void FusedAttnBackwardImpl( /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); + auto _handle_bwd = cudnnExecutionPlanManager::Instance().GetHandle(); auto backend = nvte_get_fused_attn_backend( - is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false, false, deterministic); + is_training, static_cast(dtype), static_cast(dtype), + static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, + softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, + qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, + /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_status=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git 
a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index e6781bd58a..d67cd4a6b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -6,6 +6,7 @@ #include "../extensions.h" #include "common.h" +#include "common/cudnn_utils.h" #include "pybind.h" namespace { @@ -40,17 +41,25 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s namespace transformer_engine::pytorch { // get the fused attention backend +// +// NOTE: the underlying nvte_get_fused_attn_backend now takes o_dtype and scaling_mode in +// addition to q_dtype/kv_dtype. For the F16/BF16 routing path those are ignored, so we pass +// q_dtype as o_dtype and DELAYED_TENSOR_SCALING. This Python-facing wrapper therefore keeps +// its existing signature; FP8 callers that want authoritative routing for non-default scaling +// recipes should add o_dtype / scaling_mode parameters in a follow-up. 
NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) { + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, deterministic); + is_training, static_cast(q_dtype), static_cast(kv_dtype), + static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, + attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, + return_max_logit, cuda_graph, deterministic, handle, /*out_status=*/nullptr); return fused_attention_backend; } From 16b837cd0f3e27b7638bfb2a90d39056680a6b6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 01:58:25 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/fused_attn/fused_attn.cpp | 4 +-- .../fused_attn_f16_arbitrary_seqlen.cu | 34 +++++++++---------- .../common/fused_attn/fused_attn_fp8.cu | 12 +++---- .../pytorch/csrc/extensions/attention.cpp | 4 +-- 4 files changed, 26 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp 
b/transformer_engine/common/fused_attn/fused_attn.cpp index 615f7c2a03..95405c0d6f 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -370,8 +370,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( auto bwd_status = is_supported_fp8_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, - o_t, scaling_mode, handle); + window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, o_t, + scaling_mode, handle); if (bwd_status.is_bad()) { set_status(out_status, bwd_status); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 57ca14a3e1..1ced84755c 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1391,11 +1391,10 @@ cudnn_frontend::error_t is_supported_f16_fwd( bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle) { - const ProbeDims d = compute_probe_dims(static_cast(batch), - static_cast(num_attn_heads), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv), qkv_layout, - bias_type); + const ProbeDims d = + compute_probe_dims(static_cast(batch), static_cast(num_attn_heads), + static_cast(max_seqlen_q), static_cast(max_seqlen_kv), + qkv_layout, bias_type); const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); size_t workspace_size = 0; @@ -1405,17 +1404,17 @@ 
cudnn_frontend::error_t is_supported_f16_fwd( static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), static_cast(head_dim_v), d.max_b, d.max_t_q, d.max_t_kv, d.num_pages_k, - d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, - d.max_pages_per_seq_v, d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, - return_max_logit, /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, + d.num_pages_v, d.page_size_k, d.page_size_v, d.max_pages_per_seq_k, d.max_pages_per_seq_v, + d.bias_b, d.bias_h, d.bias_sq, d.bias_skv, is_training, return_max_logit, + /*scaling_factor=*/1.0f, p_dropout, qkv_layout, o_format, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, /*devPtrV=*/nullptr, /*devPtrBias=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrS1=*/nullptr, /*devPtrS2=*/nullptr, /*devPtrO=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, /*devPtrPageTableK=*/nullptr, /*devPtrPageTableV=*/nullptr, - /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, - get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), + /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; } catch (const std::exception &e) { @@ -1432,11 +1431,10 @@ cudnn_frontend::error_t is_supported_f16_bwd( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType q_dtype, cudnnHandle_t handle) { - const ProbeDims d = 
compute_probe_dims(static_cast(batch), - static_cast(num_attn_heads), - static_cast(max_seqlen_q), - static_cast(max_seqlen_kv), qkv_layout, - bias_type); + const ProbeDims d = + compute_probe_dims(static_cast(batch), static_cast(num_attn_heads), + static_cast(max_seqlen_q), static_cast(max_seqlen_kv), + qkv_layout, bias_type); const NVTE_QKV_Format o_format = nvte_get_q_format(qkv_layout); const NVTE_QKV_Format do_format = o_format; const NVTE_QKV_Layout dqkv_layout = qkv_layout; @@ -1457,8 +1455,8 @@ cudnn_frontend::error_t is_supported_f16_bwd( /*devPtrdBias=*/nullptr, /*devPtrdSoftmaxOffset=*/nullptr, /*devPtrDropoutSeed=*/nullptr, /*devPtrDropoutOffset=*/nullptr, /*devPtrCuSeqlensQ=*/nullptr, /*devPtrCuSeqlensKV=*/nullptr, - /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, - get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, + /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), + /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; } catch (const std::exception &e) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index c9f7a9ee76..8a152cf489 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -3014,21 +3014,21 @@ cudnn_frontend::error_t is_supported_fp8_fwd( static_cast(num_gqa_groups), static_cast(max_seqlen_q), static_cast(max_seqlen_kv), static_cast(head_dim_qk), static_cast(head_dim_v), is_training, /*scaling_factor=*/1.0f, p_dropout, - qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, bottom_right_diagonal, + qkv_layout, /*o_format=*/qkv_format, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, bottom_right_diagonal, /*devPtrQ=*/nullptr, /*devPtrK=*/nullptr, 
/*devPtrV=*/nullptr, /*devPtrSoftmaxOffset=*/nullptr, /*devPtrM=*/nullptr, /*devPtrO=*/nullptr, /*devPtrDescaleQ=*/nullptr, /*devPtrDescaleK=*/nullptr, /*devPtrDescaleV=*/nullptr, /*devPtrDescaleS=*/nullptr, /*devPtrScaleS=*/nullptr, /*devPtrScaleO=*/nullptr, /*devPtrAmaxO=*/nullptr, /*devPtrAmaxS=*/nullptr, /*devPtrcuSeqlensQ=*/nullptr, /*devPtrcuSeqlensKV=*/nullptr, /*devPtrDropoutSeed=*/nullptr, - /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), - get_cudnn_fe_dtype(o_dtype), scaling_mode, + /*devPtrDropoutOffset=*/nullptr, get_cudnn_fe_dtype(q_dtype), get_cudnn_fe_dtype(o_dtype), + scaling_mode, /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; - } catch (const std::exception &e) { + } catch (const std::exception& e) { return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; } catch (...) { return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, @@ -3084,7 +3084,7 @@ cudnn_frontend::error_t is_supported_fp8_bwd( /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); return {cudnn_frontend::error_code_t::OK, ""}; - } catch (const std::exception &e) { + } catch (const std::exception& e) { return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; } catch (...) 
{ return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index d67cd4a6b9..256ede6e55 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -58,8 +58,8 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit, cuda_graph, deterministic, handle, /*out_status=*/nullptr); + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, + cuda_graph, deterministic, handle, /*out_status=*/nullptr); return fused_attention_backend; } From 42bcd89036fc8d093a192985304018c4497b22ab Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 7 May 2026 14:11:46 -0700 Subject: [PATCH 3/3] replace code+string with string only Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 98 +++++++++---------- .../fused_attn_f16_arbitrary_seqlen.cu | 18 ++-- .../fused_attn_f16_arbitrary_seqlen.h | 19 ++-- .../common/fused_attn/fused_attn_fp8.cu | 24 ++--- .../common/fused_attn/fused_attn_fp8.h | 18 ++-- .../include/transformer_engine/fused_attn.h | 37 +++---- .../jax/csrc/extensions/attention.cpp | 6 +- .../pytorch/csrc/extensions/attention.cpp | 2 +- 8 files changed, 103 insertions(+), 119 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 95405c0d6f..961b503c1c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ 
b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -228,28 +228,17 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { namespace { -// Per-thread storage for the message string handed back through -// NVTEFusedAttnBackendStatus::message. Re-used (cleared + re-populated) on every call to -// nvte_get_fused_attn_backend on this thread, which is exactly the lifetime documented in the -// public header. -thread_local std::string g_fused_attn_backend_status_buffer; - -// Apply (code, msg) to *out_status (if non-null), routing the message through the -// thread-local buffer so the returned `const char*` outlives this function call. -void set_status(NVTEFusedAttnBackendStatus *out_status, cudnn_frontend::error_code_t code, - const std::string &message) { - if (out_status == nullptr) return; - g_fused_attn_backend_status_buffer = message; - out_status->code = static_cast(code); - out_status->message = g_fused_attn_backend_status_buffer.c_str(); -} - -void set_status(NVTEFusedAttnBackendStatus *out_status, const cudnn_frontend::error_t &err) { - set_status(out_status, err.code, err.err_msg); -} - -void set_ok(NVTEFusedAttnBackendStatus *out_status) { - set_status(out_status, cudnn_frontend::error_code_t::OK, ""); +// Per-thread storage for the diagnostic string handed back through *out_reason. Re-used +// (cleared + re-populated) on every call to nvte_get_fused_attn_backend on this thread, +// which is exactly the lifetime documented in the public header. +thread_local std::string g_fused_attn_backend_reason_buffer; + +// Stash `reason` in the thread-local buffer and (if non-null) point *out_reason at it, +// so the returned `const char*` outlives this function call. 
+void set_reason(const char **out_reason, const std::string &reason) { + if (out_reason == nullptr) return; + g_fused_attn_backend_reason_buffer = reason; + *out_reason = g_fused_attn_backend_reason_buffer.c_str(); } } // namespace @@ -271,9 +260,9 @@ void set_ok(NVTEFusedAttnBackendStatus *out_status) { // executor cache-hits on. // 3. Return the selected backend, or NVTE_No_Backend if any probe rejects. // -// When `out_status` is non-null, it is filled with a code + message describing the -// rejection (or {OK, ""} on success). TE post-filter rejections synthesize an -// INVALID_VALUE entry; probe rejections forward the cuDNN-FE / NVTE_CHECK error verbatim. +// When `out_reason` is non-null, it is set to "" on success or to a tagged diagnostic +// string on rejection. TE post-filter rejections are tagged "[INVALID_VALUE] ..."; +// probe rejections forward the probe's tagged string verbatim. NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, NVTEScalingMode scaling_mode, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -281,11 +270,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, - NVTEFusedAttnBackendStatus *out_status) { + const char **out_reason) { using namespace transformer_engine; - // Initialize to OK so callers get a clean status on the success path without us having to + // Initialize to "" so callers get a clean status on the success path without us having to // remember to set it at every return. 
- set_ok(out_status); + set_reason(out_reason, ""); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -300,8 +289,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( layout_group, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v) == DType::kInt64); if (requires_64bit_ragged_offset && cudnn_runtime_version < 90500) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "Configuration requires 64-bit ragged offsets, which require cuDNN >= 9.5."); + set_reason(out_reason, + "[INVALID_VALUE] Configuration requires 64-bit ragged offsets, which require " + "cuDNN >= 9.5."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -310,8 +300,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "THD-format attention requires a padding-style mask " + set_reason(out_reason, + "[INVALID_VALUE] THD-format attention requires a padding-style mask " "(PADDING / PADDING_CAUSAL / PADDING_CAUSAL_BOTTOM_RIGHT)."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -324,8 +314,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " + set_reason(out_reason, + "[INVALID_VALUE] Known cuDNN <= 9.15 capture quirk: training + bshd/sbhd + " "max_seqlen_kv % 128 != 0 + cuda_graph + non-padding mask is unsupported."); return 
NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -346,34 +336,34 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (is_fp8) { // TE-only FP8 post-filters: no 64-bit ragged offsets, no max-logit output. if (requires_64bit_ragged_offset) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention does not support 64-bit ragged offsets."); + set_reason(out_reason, + "[INVALID_VALUE] FP8 fused attention does not support 64-bit ragged offsets."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (return_max_logit) { - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention does not support return_max_logit."); + set_reason(out_reason, + "[INVALID_VALUE] FP8 fused attention does not support return_max_logit."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } const DType q_t = static_cast(q_dtype); const DType o_t = static_cast(o_dtype); - auto fwd_status = is_supported_fp8_fwd( + std::string fwd_reason = is_supported_fp8_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, o_t, scaling_mode, handle); - if (fwd_status.is_bad()) { - set_status(out_status, fwd_status); + if (!fwd_reason.empty()) { + set_reason(out_reason, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { - auto bwd_status = is_supported_fp8_bwd( + std::string bwd_reason = is_supported_fp8_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, o_t, scaling_mode, handle); - if (bwd_status.is_bad()) { - set_status(out_status, bwd_status); + if (!bwd_reason.empty()) { + set_reason(out_reason, bwd_reason); 
return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } @@ -382,31 +372,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if (is_f16_or_bf16) { const DType q_t = static_cast(q_dtype); - auto fwd_status = is_supported_f16_fwd( + std::string fwd_reason = is_supported_f16_fwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, return_max_logit, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, q_t, handle); - if (fwd_status.is_bad()) { - set_status(out_status, fwd_status); + if (!fwd_reason.empty()) { + set_reason(out_reason, fwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } if (is_training) { - auto bwd_status = is_supported_f16_bwd( + std::string bwd_reason = is_supported_f16_bwd( probe_batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, probe_bottom_right_diagonal, deterministic, q_t, handle); - if (bwd_status.is_bad()) { - set_status(out_status, bwd_status); + if (!bwd_reason.empty()) { + set_reason(out_reason, bwd_reason); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } } return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - set_status(out_status, cudnn_frontend::error_code_t::INVALID_VALUE, - "Unsupported Q dtype for fused attention " + set_reason(out_reason, + "[INVALID_VALUE] Unsupported Q dtype for fused attention " "(only FP16/BF16/FP8_E4M3/FP8_E5M2 are routable)."); return NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -500,7 +490,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso is_training, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, 
/*deterministic=*/false, handle, - /*out_status=*/nullptr); + /*out_reason=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, @@ -589,7 +579,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso /*is_training=*/true, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, /*return_max_logit=*/false, cuda_graph, deterministic, - handle, /*out_status=*/nullptr); + handle, /*out_reason=*/nullptr); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1ced84755c..70094c4e93 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1385,7 +1385,7 @@ ProbeDims compute_probe_dims(int64_t batch, int64_t num_attn_heads, int64_t max_ } } // namespace -cudnn_frontend::error_t is_supported_f16_fwd( +std::string is_supported_f16_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -1416,16 +1416,15 @@ cudnn_frontend::error_t is_supported_f16_fwd( /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception &e) { - 
return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_f16_fwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_f16_fwd: unknown failure"; } } -cudnn_frontend::error_t is_supported_f16_bwd( +std::string is_supported_f16_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -1458,12 +1457,11 @@ cudnn_frontend::error_t is_supported_f16_bwd( /*devPtrSeqOffsetsQ=*/nullptr, /*devPtrSeqOffsetsKV=*/nullptr, get_cudnn_fe_dtype(q_dtype), /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception &e) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) 
{ - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_f16_bwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_f16_bwd: unknown failure"; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 38cf48c1f0..0eabe3e8dc 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -12,7 +12,8 @@ #define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_ARBITRARY_SEQLEN_H_ #include <cudnn.h> -#include <cudnn_frontend.h> + +#include <string> #include "common/common.h" #include "transformer_engine/fused_attn.h" @@ -50,11 +51,15 @@ void fused_attn_arbitrary_seqlen_bwd( // Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> // check_support -> build_plans) for an F16/BF16 forward graph with the given configuration. -// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK, -// the built graph is inserted into the same thread-local cache used by -// fused_attn_arbitrary_seqlen_fwd_impl, so the executor cache-hits on matching descriptors. -// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. -cudnn_frontend::error_t is_supported_f16_fwd( +// Returns an empty string iff the graph compiles end-to-end; on OK the built graph is +// inserted into the same thread-local cache used by fused_attn_arbitrary_seqlen_fwd_impl, +// so the executor cache-hits on matching descriptors. +// +// On rejection, returns a non-empty diagnostic of the form +// "[<CATEGORY>] <message>" +// where <CATEGORY> is a stable tag mirroring cudnn_frontend::error_code_t names +// (e.g. GRAPH_NOT_SUPPORTED for cuDNN-FE rejections forwarded from the support chain). 
+std::string is_supported_f16_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, bool return_max_logit, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -62,7 +67,7 @@ cudnn_frontend::error_t is_supported_f16_fwd( int64_t window_size_right, bool bottom_right_diagonal, DType q_dtype, cudnnHandle_t handle); // Probe: same as above for the F16/BF16 backward graph. -cudnn_frontend::error_t is_supported_f16_bwd( +std::string is_supported_f16_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 8a152cf489..27bd0af3f3 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2992,7 +2992,7 @@ void fused_attn_fp8_bwd( } } -cudnn_frontend::error_t is_supported_fp8_fwd( +std::string is_supported_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -3004,8 +3004,7 @@ cudnn_frontend::error_t is_supported_fp8_fwd( const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return {cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + return "[INVALID_VALUE] FP8 fused attention only supports BSHD/SBHD/BHSD layouts."; } size_t workspace_size = 
0; try { @@ -3027,16 +3026,15 @@ cudnn_frontend::error_t is_supported_fp8_fwd( /*qkv_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception& e) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_fp8_fwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_fp8_fwd: unknown failure"; } } -cudnn_frontend::error_t is_supported_fp8_bwd( +std::string is_supported_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -3046,8 +3044,7 @@ cudnn_frontend::error_t is_supported_fp8_bwd( const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && qkv_format != NVTE_QKV_Format::NVTE_BHSD) { - return {cudnn_frontend::error_code_t::INVALID_VALUE, - "FP8 fused attention only supports BSHD/SBHD/BHSD layouts."}; + return "[INVALID_VALUE] FP8 fused attention only supports BSHD/SBHD/BHSD layouts."; } // For FP8 bwd, dO data type matches O data type and dQKV data type matches Q data type // (this mirrors the assumption used by callers of fused_attn_fp8_bwd in TE). 
@@ -3083,12 +3080,11 @@ cudnn_frontend::error_t is_supported_fp8_bwd( /*do_scale_inv_format=*/NVTE_QKV_Format::NVTE_QKV_Format_NOT_SET, /*workspace=*/nullptr, &workspace_size, /*stream=*/static_cast<cudaStream_t>(0), handle); - return {cudnn_frontend::error_code_t::OK, ""}; + return ""; } catch (const std::exception& e) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, e.what()}; + return std::string("[GRAPH_NOT_SUPPORTED] ") + e.what(); } catch (...) { - return {cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED, - "is_supported_fp8_bwd: unknown failure"}; + return "[GRAPH_NOT_SUPPORTED] is_supported_fp8_bwd: unknown failure"; } } diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 5c7f11d80e..f91cdcf291 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -8,7 +8,7 @@ * \brief Functions for fused attention for FP8 with seqlen <= 512 */ -#include <cudnn_frontend.h> +#include <string> #include "transformer_engine/fused_attn.h" #include "transformer_engine/transformer_engine.h" @@ -44,10 +44,15 @@ void fused_attn_fp8_bwd( // Probe: drives cuDNN-FE (validate -> build_operation_graph -> create_execution_plans -> // check_support -> build_plans) for an FP8 forward graph with the given configuration. -// Returns the cuDNN-FE status: error_code_t::OK iff the graph compiles end-to-end. On OK, -// the built graph is inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. -// On rejection, err_msg contains the underlying cuDNN-FE / NVTE_CHECK message. -cudnn_frontend::error_t is_supported_fp8_fwd( +// Returns an empty string iff the graph compiles end-to-end; on OK the built graph is +// inserted into the same thread-local cache used by fused_attn_fp8_fwd_impl. 
+// +// On rejection, returns a non-empty diagnostic of the form +// "[<CATEGORY>] <message>" +// where <CATEGORY> mirrors cudnn_frontend::error_code_t names (INVALID_VALUE for the +// FP8-only layout pre-filter, GRAPH_NOT_SUPPORTED for cuDNN-FE rejections forwarded +// from the support chain). +std::string is_supported_fp8_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, bool is_training, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, @@ -56,12 +61,11 @@ cudnn_frontend::error_t is_supported_fp8_fwd( cudnnHandle_t handle); // Probe: same as above for the FP8 backward graph. -cudnn_frontend::error_t is_supported_fp8_bwd( +std::string is_supported_fp8_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, DType q_dtype, DType o_dtype, NVTEScalingMode scaling_mode, cudnnHandle_t handle); ->>>>>>> c9006435 (refactor nvte_get_fused_attn_backend with FE calls) } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 787e97d628..bbcdf08995 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -198,22 +198,6 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); -/*! \struct NVTEFusedAttnBackendStatus - * \brief Diagnostic info from \c nvte_get_fused_attn_backend. 
- * - * Filled by \c nvte_get_fused_attn_backend when the caller passes a non-NULL pointer. - * When the routing decision is supported, \c code is 0 and \c message is the empty - * string. When the routing rejects the configuration, \c code is the underlying - * cuDNN-FE \c cudnn_frontend::error_code_t cast to \c int (TE-synthesized post-filter - * rejections use \c INVALID_VALUE), and \c message is a null-terminated human-readable - * reason that points into per-thread storage owned by TE. The pointer is valid only - * until the next call to \c nvte_get_fused_attn_backend on the same thread. - */ -typedef struct NVTEFusedAttnBackendStatus { - int code; - const char *message; -} NVTEFusedAttnBackendStatus; - /*! \brief Get fused attention backend based on input parameters. * * Authoritative routing: when a non-NVTE_No_Backend value is returned, the configuration @@ -249,12 +233,19 @@ typedef struct NVTEFusedAttnBackendStatus { * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] deterministic Whether determinism is required or not. * \param[in] handle cuDNN handle used for the support chain. Required. - * \param[out] out_status Optional. When non-NULL, populated with a code + - * message describing why the configuration was - * rejected (NVTE_No_Backend) or with code=0 and - * message="" on success. The message buffer lives in - * thread-local storage and is overwritten on every - * call on the same thread. + * \param[out] out_reason Optional. When non-NULL, set to a null-terminated + * diagnostic string describing why the configuration + * was rejected (NVTE_No_Backend) or set to "" on + * success. Rejection messages are tagged with a + * stable category prefix that mirrors + * \c cudnn_frontend::error_code_t, e.g. + * \c "[INVALID_VALUE] ..." for TE post-filter + * rejections and FP8 layout pre-filter rejections, + * \c "[GRAPH_NOT_SUPPORTED] ..." for cuDNN-FE + * rejections forwarded from the support chain. 
The + * pointer points into per-thread storage owned by TE + * and is valid only until the next call to + * \c nvte_get_fused_attn_backend on the same thread. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTEDType o_dtype, @@ -263,7 +254,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, - NVTEFusedAttnBackendStatus *out_status); + const char **out_reason); /*! \brief Compute dot product attention with separate Q, K and V. * diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index c6a8897089..669570daa5 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -25,7 +25,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, /*cuda_graph=*/false, deterministic, handle, - /*out_status=*/nullptr); + /*out_reason=*/nullptr); return backend; } @@ -282,7 +282,7 @@ static void FusedAttnForwardImpl( static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_status=*/nullptr); + /*cuda_graph=*/false, deterministic, _handle_fwd, /*out_reason=*/nullptr); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary 
tensors (to be propagated to the backward pass later) */ @@ -560,7 +560,7 @@ static void FusedAttnBackwardImpl( static_cast(dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, /*return_max_logit=*/false, - /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_status=*/nullptr); + /*cuda_graph=*/false, deterministic, _handle_bwd, /*out_reason=*/nullptr); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias, softmax_offset); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 256ede6e55..3af4ba3831 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -59,7 +59,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( static_cast(q_dtype), NVTE_DELAYED_TENSOR_SCALING, qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, return_max_logit, - cuda_graph, deterministic, handle, /*out_status=*/nullptr); + cuda_graph, deterministic, handle, /*out_reason=*/nullptr); return fused_attention_backend; }