Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
405 changes: 213 additions & 192 deletions transformer_engine/common/common.h

Large diffs are not rendered by default.

94 changes: 90 additions & 4 deletions transformer_engine/common/fused_attn/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,88 @@
#include "fused_attn_fp8.h"
#include "utils.h"

namespace transformer_engine {

// Returns the human-readable enumerator name for a QKV layout.
// Unrecognized values are rendered with their numeric code so error
// messages built from this string remain diagnosable.
std::string to_string(NVTE_QKV_Layout layout) {
  const char *name = nullptr;
  switch (layout) {
    case NVTE_SB3HD:                    name = "NVTE_SB3HD"; break;
    case NVTE_SBH3D:                    name = "NVTE_SBH3D"; break;
    case NVTE_SBHD_SB2HD:               name = "NVTE_SBHD_SB2HD"; break;
    case NVTE_SBHD_SBH2D:               name = "NVTE_SBHD_SBH2D"; break;
    case NVTE_SBHD_SBHD_SBHD:           name = "NVTE_SBHD_SBHD_SBHD"; break;
    case NVTE_BS3HD:                    name = "NVTE_BS3HD"; break;
    case NVTE_BSH3D:                    name = "NVTE_BSH3D"; break;
    case NVTE_BSHD_BS2HD:               name = "NVTE_BSHD_BS2HD"; break;
    case NVTE_BSHD_BSH2D:               name = "NVTE_BSHD_BSH2D"; break;
    case NVTE_BSHD_BSHD_BSHD:           name = "NVTE_BSHD_BSHD_BSHD"; break;
    case NVTE_T3HD:                     name = "NVTE_T3HD"; break;
    case NVTE_TH3D:                     name = "NVTE_TH3D"; break;
    case NVTE_THD_T2HD:                 name = "NVTE_THD_T2HD"; break;
    case NVTE_THD_TH2D:                 name = "NVTE_THD_TH2D"; break;
    case NVTE_THD_THD_THD:              name = "NVTE_THD_THD_THD"; break;
    case NVTE_SBHD_BSHD_BSHD:           name = "NVTE_SBHD_BSHD_BSHD"; break;
    case NVTE_BSHD_SBHD_SBHD:           name = "NVTE_BSHD_SBHD_SBHD"; break;
    case NVTE_THD_BSHD_BSHD:            name = "NVTE_THD_BSHD_BSHD"; break;
    case NVTE_THD_SBHD_SBHD:            name = "NVTE_THD_SBHD_SBHD"; break;
    case NVTE_Paged_KV_BSHD_BSHD_BSHD:  name = "NVTE_Paged_KV_BSHD_BSHD_BSHD"; break;
    case NVTE_Paged_KV_BSHD_SBHD_SBHD:  name = "NVTE_Paged_KV_BSHD_SBHD_SBHD"; break;
    case NVTE_Paged_KV_SBHD_BSHD_BSHD:  name = "NVTE_Paged_KV_SBHD_BSHD_BSHD"; break;
    case NVTE_Paged_KV_SBHD_SBHD_SBHD:  name = "NVTE_Paged_KV_SBHD_SBHD_SBHD"; break;
    case NVTE_Paged_KV_THD_BSHD_BSHD:   name = "NVTE_Paged_KV_THD_BSHD_BSHD"; break;
    case NVTE_Paged_KV_THD_SBHD_SBHD:   name = "NVTE_Paged_KV_THD_SBHD_SBHD"; break;
    default:
      break;  // fall through to the UNKNOWN path below
  }
  if (name != nullptr) {
    return name;
  }
  return "UNKNOWN_QKV_LAYOUT(" + std::to_string(static_cast<int>(layout)) + ")";
}

// Returns the human-readable enumerator name for a QKV format.
// Unrecognized values are rendered with their numeric code.
std::string to_string(NVTE_QKV_Format format) {
  const char *name = nullptr;
  switch (format) {
    case NVTE_SBHD:        name = "NVTE_SBHD"; break;
    case NVTE_BSHD:        name = "NVTE_BSHD"; break;
    case NVTE_THD:         name = "NVTE_THD"; break;
    case NVTE_BSHD_2SBHD:  name = "NVTE_BSHD_2SBHD"; break;
    case NVTE_SBHD_2BSHD:  name = "NVTE_SBHD_2BSHD"; break;
    case NVTE_THD_2BSHD:   name = "NVTE_THD_2BSHD"; break;
    case NVTE_THD_2SBHD:   name = "NVTE_THD_2SBHD"; break;
    default:
      break;  // fall through to the UNKNOWN path below
  }
  if (name != nullptr) {
    return name;
  }
  return "UNKNOWN_QKV_FORMAT(" + std::to_string(static_cast<int>(format)) + ")";
}

}  // namespace transformer_engine

// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
switch (qkv_layout) {
Expand Down Expand Up @@ -50,7 +132,8 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD;
default:
NVTE_ERROR("qkv_layout not supported!");
NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout),
" in nvte_get_qkv_layout_group.");
}
}

Expand Down Expand Up @@ -90,7 +173,8 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD:
return NVTE_QKV_Format::NVTE_THD_2SBHD;
default:
NVTE_ERROR("qkv_layout not supported!");
NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout),
" in nvte_get_qkv_format.");
}
}

Expand All @@ -109,7 +193,8 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Format::NVTE_THD_2SBHD:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format),
" in nvte_get_q_format.");
}
}

Expand All @@ -128,7 +213,8 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
case NVTE_QKV_Format::NVTE_THD:
return NVTE_QKV_Format::NVTE_THD;
default:
NVTE_ERROR("qkv_layout not supported!");
NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format),
" in nvte_get_kv_format.");
}
}

Expand Down
79 changes: 41 additions & 38 deletions transformer_engine/common/fused_router/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,46 +216,49 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i
}

// Current TE only support float32/bf16/fp16, float64 probs should be considered in the future
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat16: { \
using type = fp16; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
// Dispatch macro for router probability dtypes: binds the alias `type` to the
// concrete C++ type matching the runtime `dtype` (Float32 -> float,
// Float16 -> fp16, BFloat16 -> bf16) and executes the body (__VA_ARGS__) once
// with that alias in scope. Any other dtype raises an NVTE error that names
// the offending dtype and the accepted set.
#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...)                                 \
  switch (dtype) {                                                                        \
    using namespace transformer_engine;                                                   \
    case DType::kFloat32: {                                                               \
      using type = float;                                                                 \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    case DType::kFloat16: {                                                               \
      using type = fp16;                                                                  \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    case DType::kBFloat16: {                                                              \
      using type = bf16;                                                                  \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    default:                                                                              \
      NVTE_ERROR("Unsupported router probs dtype ", to_string(static_cast<DType>(dtype)), \
                 ". Expected one of: Float32, Float16, BFloat16.");                       \
  }

#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \
switch (dtype) { \
using namespace transformer_engine; \
case DType::kInt32: { \
using type = int32_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kInt64: { \
using type = int64_t; \
{ __VA_ARGS__ } \
} break; \
case DType::kBFloat16: { \
using type = bf16; \
{ __VA_ARGS__ } \
} break; \
case DType::kFloat32: { \
using type = float; \
{ __VA_ARGS__ } \
} break; \
default: \
NVTE_ERROR("Invalid type."); \
// Dispatch macro for router index dtypes: binds the alias `type` to the
// concrete C++ type matching the runtime `dtype` (Int32 -> int32_t,
// Int64 -> int64_t, BFloat16 -> bf16, Float32 -> float) and executes the body
// (__VA_ARGS__) once with that alias in scope. Any other dtype raises an NVTE
// error that names the offending dtype and the accepted set.
#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...)                                 \
  switch (dtype) {                                                                        \
    using namespace transformer_engine;                                                   \
    case DType::kInt32: {                                                                 \
      using type = int32_t;                                                               \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    case DType::kInt64: {                                                                 \
      using type = int64_t;                                                               \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    case DType::kBFloat16: {                                                              \
      using type = bf16;                                                                  \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    case DType::kFloat32: {                                                               \
      using type = float;                                                                 \
      { __VA_ARGS__ }                                                                     \
    } break;                                                                              \
    default:                                                                              \
      NVTE_ERROR("Unsupported router index dtype ", to_string(static_cast<DType>(dtype)), \
                 ". Expected one of: Int32, Int64, BFloat16, "                            \
                 "Float32.");                                                             \
  }
} // namespace transformer_engine
#endif
6 changes: 3 additions & 3 deletions transformer_engine/common/gemm/cutlass_grouped_gemm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -326,17 +326,17 @@ void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,

// Check can implement the kernel.
if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) {
NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM");
NVTE_ERROR("Failed to implement CUTLASS Grouped GEMM with ", num_gemms, " GEMMs");
}

// Initialize the kernel.
if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) {
NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
NVTE_ERROR("Failed to initialize CUTLASS Grouped GEMM with ", num_gemms, " GEMMs");
}

// Execute the kernel in the current stream.
if (gemm.run(stream) != cutlass::Status::kSuccess) {
NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
NVTE_ERROR("Failed to run CUTLASS Grouped GEMM with ", num_gemms, " GEMMs");
}
}

Expand Down
8 changes: 6 additions & 2 deletions transformer_engine/common/normalization/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ void TeNormalizationPlan<BackwardKernelParams>::execute(Tensor* z, void* x_dptr,
void* beta_dptr, void* mean_dptr,
void* eps_dptr, void* rsigma_dptr,
void* workspace_dptr, cudaStream_t stream) {
NVTE_ERROR("Backward normalization should not call the forward execute function!");
NVTE_ERROR(
"Backward normalization should not call the forward execute function. "
"Use the backward-specific execute overload instead.");
}

template <typename KernelParamsType>
Expand Down Expand Up @@ -165,7 +167,9 @@ void TeNormalizationPlan<ForwardKernelParams>::execute(void* x_dptr, void* gamma
void* dx_dptr, void* dz_dptr, void* add_dptr,
void* dbeta_dptr, void* dgamma_dptr,
void* workspace_dptr, cudaStream_t stream) {
NVTE_ERROR("Forward normalization should not call the backward execute function!");
NVTE_ERROR(
"Forward normalization should not call the backward execute function. "
"Use the forward-specific execute overload instead.");
}

template <>
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor");
NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_shape");
}

// Determine tensor shape depending on tensor format
Expand All @@ -662,7 +662,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
NVTE_ERROR("Invalid tensor");
NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_columnwise_shape");
}
const std::vector<size_t> &shape = t->columnwise_data.shape;
return nvte_make_shape(shape.data(), shape.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,8 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size
std::is_same_v<OutputType, __nv_fp8_e5m2>) {
dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else {
NVTE_CHECK(false, "Invalid Output type (must be FP8).");
NVTE_ERROR(
"Invalid output type for blockwise transpose (must be FP8: Float8E4M3 or Float8E5M2).");
}

CUtensorMap tensor_map_output_trans{};
Expand Down
Loading
Loading