Skip to content

Commit 8f40cab

Browse files
fixes from the feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
1 parent b5b2b9d commit 8f40cab

3 files changed

Lines changed: 8 additions & 8 deletions

File tree

transformer_engine/common/fused_attn/utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,23 +118,23 @@ struct FADescriptor_v1 {
118118
cudnn_frontend::DataType_t o_tensor_type;
119119
cudnn_frontend::DataType_t do_tensor_type;
120120
cudnn_frontend::DataType_t dqkv_tensor_type;
121-
bool generate_max_sum_exp;
121+
bool return_max_logit;
122122

123123
bool operator<(const FADescriptor_v1 &rhs) const {
124124
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
125125
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
126126
bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
127127
softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
128128
deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
129-
dqkv_tensor_type, generate_max_sum_exp) <
129+
dqkv_tensor_type, return_max_logit) <
130130
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
131131
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
132132
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
133133
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
134134
rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
135135
rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
136136
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
137-
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
137+
rhs.dqkv_tensor_type, rhs.return_max_logit);
138138
}
139139
};
140140

transformer_engine/common/include/transformer_engine/fused_attn.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
206206
* \param[in] head_dim_v The head dimension of V.
207207
* \param[in] window_size_left Sliding window size (the left half).
208208
* \param[in] window_size_right Sliding window size (the right half).
209-
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
209+
* \param[in] return_max_logit Whether to produce Max along with Stats.
210210
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
211211
* \param[in] deterministic Whether determinism is required or not.
212212
*/
@@ -260,7 +260,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
260260
* \param[in] max_seqlen Max sequence length used for computing,
261261
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
262262
* \param[in] is_training Whether this is in training mode or inference.
263-
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
263+
* \param[in] return_max_logit Whether to produce Max along with Stats.
264264
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
265265
* \param[in] attn_scale Scaling factor for Q * K.T.
266266
* \param[in] dropout Dropout probability.
@@ -400,7 +400,7 @@ void nvte_fused_attn_bwd_qkvpacked(
400400
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
401401
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
402402
* \param[in] is_training Whether this is in training mode or inference.
403-
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
403+
* \param[in] return_max_logit Whether to produce Max along with Stats.
404404
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
405405
* \param[in] attn_scale Scaling factor for Q * K.T.
406406
* \param[in] dropout Dropout probability.
@@ -553,7 +553,7 @@ void nvte_fused_attn_bwd_kvpacked(
553553
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
554554
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
555555
* \param[in] is_training Whether this is in training mode or inference.
556-
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
556+
* \param[in] return_max_logit Whether to produce Max along with Stats.
557557
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
558558
* \param[in] attn_scale Scaling factor for Q * K.T.
559559
* \param[in] dropout Dropout probability.

transformer_engine/pytorch/csrc/extensions/attention.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ std::vector<py::object> fused_attn_fwd(
263263
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
264264
size_t i = 0;
265265
at::Tensor output_tensor;
266-
// intermediate softmax tensor, S (or `Stats`)
266+
// intermediate softmax tensor, S or M (for fp8)
267267
output_tensor =
268268
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
269269
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);

0 commit comments

Comments (0)