From 7dc87c0bbbde9414c2f5f74136596bb4ca1cee69 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 25 Feb 2026 11:46:52 -0800 Subject: [PATCH 1/3] Added better error messages Signed-off-by: Przemek Tredak --- transformer_engine/common/common.h | 54 ++++-- .../common/fused_attn/fused_attn.cpp | 76 +++++++- .../common/fused_router/utils.h | 9 +- .../common/gemm/cutlass_grouped_gemm.cuh | 9 +- .../common/normalization/common.cpp | 6 +- .../common/transformer_engine.cpp | 4 +- .../quantize_transpose_square_blockwise.cu | 2 +- .../jax/cpp_extensions/attention.py | 42 ++--- transformer_engine/jax/cpp_extensions/gemm.py | 34 ++-- .../jax/cpp_extensions/normalization.py | 40 ++-- transformer_engine/jax/flax/transformer.py | 65 ++++--- transformer_engine/jax/layernorm.py | 6 +- .../pytorch/cpp_extensions/fused_attn.py | 82 +++++---- transformer_engine/pytorch/cpu_offload.py | 28 ++- transformer_engine/pytorch/csrc/common.cpp | 4 +- transformer_engine/pytorch/csrc/quantizer.cpp | 6 +- .../pytorch/custom_recipes/gemm.py | 26 +-- .../custom_recipes/quantization_nvfp4.py | 30 ++- transformer_engine/pytorch/distributed.py | 124 ++++++++++--- transformer_engine/pytorch/graph.py | 43 ++++- transformer_engine/pytorch/module/base.py | 173 ++++++++++++------ transformer_engine/pytorch/permutation.py | 173 ++++++++++++++---- transformer_engine/pytorch/quantization.py | 3 +- transformer_engine/pytorch/tensor/utils.py | 96 ++++++++-- transformer_engine/pytorch/transformer.py | 92 +++++++--- transformer_engine/pytorch/utils.py | 80 +++++--- 26 files changed, 923 insertions(+), 384 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0c722634f3..4e9800a69d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -38,6 +38,8 @@ namespace transformer_engine { std::string to_string(const DType type); std::string to_string(const NVTEScalingMode &mode); +inline std::string to_string_like(const DType &val) 
{ return to_string(val); } + inline bool is_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } @@ -649,7 +651,11 @@ struct TypeInfo { } break; \ SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Byte, Int32, Int64, Float32, " \ + "Float16, BFloat16, Float8E4M3, Float8E5M2, " \ + "Float8E8M0, Float4E2M1."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \ @@ -676,7 +682,10 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E4M3, Float8E5M2."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ @@ -703,7 +712,10 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported output dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E5M2, Float8E4M3."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ @@ -722,7 +734,9 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \ @@ -737,7 +751,9 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type, expected Float32 or BFloat16."); \ + NVTE_ERROR("Unsupported dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". 
Expected one of: Float32, BFloat16."); \ } // Add a pack_size argument to select the packed type for FP4 @@ -749,7 +765,9 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected: Float4E2M1."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ @@ -764,7 +782,9 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float8E5M2, Float8E4M3."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ @@ -784,13 +804,20 @@ struct TypeInfo { } break; \ case DType::kFloat8E5M2: \ case DType::kFloat8E4M3: { \ - NVTE_ERROR("FP8 type not instantiated for input."); \ + NVTE_ERROR("FP8 dtype ", \ + to_string(static_cast<DType>(dtype)), \ + " is not instantiated for input. " \ + "Expected one of: Float32, Float16, BFloat16."); \ } break; \ case DType::kFloat4E2M1: { \ + NVTE_ERROR("FP4 dtype Float4E2M1 is not instantiated " \ + "for input. Expected one of: Float32, Float16, " \ + "BFloat16."); \ - NVTE_ERROR("FP4 type not instantiated for input."); \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported input dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ @@ -807,7 +834,9 @@ struct TypeInfo { break; \ } \ default: \ - NVTE_ERROR("Invalid type for 16 bit."); \ + NVTE_ERROR("Unsupported 16-bit dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float16, BFloat16."); \ } #define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) 
\ @@ -821,7 +850,8 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: { \ - NVTE_ERROR("Invalid size of the MX scaling factor."); \ + NVTE_ERROR("Unsupported MX scaling factor dimension ", \ + SCALE_DIM, ". Expected one of: 1, 32."); \ } \ } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index b5679280c6..1c491ab807 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,6 +15,54 @@ #include "fused_attn_fp8.h" #include "utils.h" +namespace transformer_engine { + +std::string to_string(NVTE_QKV_Layout layout) { + switch (layout) { + case NVTE_SB3HD: return "NVTE_SB3HD"; + case NVTE_SBH3D: return "NVTE_SBH3D"; + case NVTE_SBHD_SB2HD: return "NVTE_SBHD_SB2HD"; + case NVTE_SBHD_SBH2D: return "NVTE_SBHD_SBH2D"; + case NVTE_SBHD_SBHD_SBHD: return "NVTE_SBHD_SBHD_SBHD"; + case NVTE_BS3HD: return "NVTE_BS3HD"; + case NVTE_BSH3D: return "NVTE_BSH3D"; + case NVTE_BSHD_BS2HD: return "NVTE_BSHD_BS2HD"; + case NVTE_BSHD_BSH2D: return "NVTE_BSHD_BSH2D"; + case NVTE_BSHD_BSHD_BSHD: return "NVTE_BSHD_BSHD_BSHD"; + case NVTE_T3HD: return "NVTE_T3HD"; + case NVTE_TH3D: return "NVTE_TH3D"; + case NVTE_THD_T2HD: return "NVTE_THD_T2HD"; + case NVTE_THD_TH2D: return "NVTE_THD_TH2D"; + case NVTE_THD_THD_THD: return "NVTE_THD_THD_THD"; + case NVTE_SBHD_BSHD_BSHD: return "NVTE_SBHD_BSHD_BSHD"; + case NVTE_BSHD_SBHD_SBHD: return "NVTE_BSHD_SBHD_SBHD"; + case NVTE_THD_BSHD_BSHD: return "NVTE_THD_BSHD_BSHD"; + case NVTE_THD_SBHD_SBHD: return "NVTE_THD_SBHD_SBHD"; + case NVTE_Paged_KV_BSHD_BSHD_BSHD: return "NVTE_Paged_KV_BSHD_BSHD_BSHD"; + case NVTE_Paged_KV_BSHD_SBHD_SBHD: return "NVTE_Paged_KV_BSHD_SBHD_SBHD"; + case NVTE_Paged_KV_SBHD_BSHD_BSHD: return "NVTE_Paged_KV_SBHD_BSHD_BSHD"; + case NVTE_Paged_KV_SBHD_SBHD_SBHD: return "NVTE_Paged_KV_SBHD_SBHD_SBHD"; + case NVTE_Paged_KV_THD_BSHD_BSHD: return "NVTE_Paged_KV_THD_BSHD_BSHD"; + case 
NVTE_Paged_KV_THD_SBHD_SBHD: return "NVTE_Paged_KV_THD_SBHD_SBHD"; + default: return "UNKNOWN_QKV_LAYOUT(" + std::to_string(static_cast<int>(layout)) + ")"; + } +} + +std::string to_string(NVTE_QKV_Format format) { + switch (format) { + case NVTE_SBHD: return "NVTE_SBHD"; + case NVTE_BSHD: return "NVTE_BSHD"; + case NVTE_THD: return "NVTE_THD"; + case NVTE_BSHD_2SBHD: return "NVTE_BSHD_2SBHD"; + case NVTE_SBHD_2BSHD: return "NVTE_SBHD_2BSHD"; + case NVTE_THD_2BSHD: return "NVTE_THD_2BSHD"; + case NVTE_THD_2SBHD: return "NVTE_THD_2SBHD"; + default: return "UNKNOWN_QKV_FORMAT(" + std::to_string(static_cast<int>(format)) + ")"; + } +} + +} // namespace transformer_engine + namespace { // Helper function to create a tensor view with modified shape and optional pointer offset transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source, @@ -118,7 +166,9 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_layout ", + transformer_engine::to_string(qkv_layout), + " in nvte_get_qkv_layout_group."); } } @@ -158,7 +208,9 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_layout ", + transformer_engine::to_string(qkv_layout), + " in nvte_get_qkv_format."); } } @@ -177,7 +229,9 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_format ", + transformer_engine::to_string(qkv_format), + " in nvte_get_q_format."); } } @@ -196,7 +250,9 @@ 
case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("qkv_layout not supported!"); + NVTE_ERROR("Unsupported qkv_format ", + transformer_engine::to_string(qkv_format), + " in nvte_get_kv_format."); } } @@ -549,7 +605,8 @@ void nvte_fused_attn_fwd_qkvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { h = input_QKV->data.shape[ndim - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts, got ", + transformer_engine::to_string(qkv_layout), "."); } size_t d = input_QKV->data.shape[ndim - 1]; size_t t = 0; @@ -667,7 +724,8 @@ void nvte_fused_attn_bwd_qkvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { h = input_QKV->data.shape[ndim - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!"); + NVTE_ERROR("nvte_fused_attn_bwd_qkvpacked only supports H3D and 3HD layouts, got ", + transformer_engine::to_string(qkv_layout), "."); } size_t d = input_QKV->data.shape[ndim - 1]; size_t t = 0; @@ -824,7 +882,8 @@ void nvte_fused_attn_fwd_kvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { h_kv = input_KV->data.shape[ndim_kv - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts, got ", + transformer_engine::to_string(qkv_layout), "."); } size_t t_q = 0; size_t t_kv = 0; @@ -981,7 +1040,8 @@ void nvte_fused_attn_bwd_kvpacked( } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { h_kv = input_KV->data.shape[ndim_kv - 3]; } else { - NVTE_ERROR("nvte_fused_attn_fwd_kvpacked only supports HD_H2D and HD_2HD layouts!"); + NVTE_ERROR("nvte_fused_attn_bwd_kvpacked only supports HD_H2D and HD_2HD layouts, got ", + transformer_engine::to_string(qkv_layout), "."); } size_t t_q = 0; size_t t_kv = 0; diff 
--git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 4ae0b467b5..7dd2cf1d13 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -232,7 +232,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported router probs dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16.");\ } #define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ @@ -255,7 +257,10 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Invalid type."); \ + NVTE_ERROR("Unsupported router index dtype ", \ + to_string(static_cast<DType>(dtype)), \ + ". Expected one of: Int32, Int64, BFloat16, " \ + "Float32."); \ } } // namespace transformer_engine #endif diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh index eb99edc4d3..16b37813f1 100644 --- a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -326,17 +326,20 @@ void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, // Check can implement the kernel. if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to implement CUTLASS Grouped GEMM with ", + num_gemms, " GEMMs"); } // Initialize the kernel. if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to initialize CUTLASS Grouped GEMM with ", + num_gemms, " GEMMs"); } // Execute the kernel in the current stream. 
if (gemm.run(stream) != cutlass::Status::kSuccess) { - NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + NVTE_ERROR("Failed to run CUTLASS Grouped GEMM with ", + num_gemms, " GEMMs"); } } diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 852b418b39..ba8b9930da 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -116,7 +116,8 @@ void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* beta_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Backward normalization should not call the forward execute function!"); + NVTE_ERROR("Backward normalization should not call the forward execute function. " + "Use the backward-specific execute overload instead."); } template @@ -165,7 +166,8 @@ void TeNormalizationPlan::execute(void* x_dptr, void* gamma void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Forward normalization should not call the backward execute function!"); + NVTE_ERROR("Forward normalization should not call the backward execute function. 
" + "Use the forward-specific execute overload instead."); } template <> diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 06971443dd..07d4a71a39 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -650,7 +650,7 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) { NVTEShape nvte_tensor_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - NVTE_ERROR("Invalid tensor"); + NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_shape"); } // Determine tensor shape depending on tensor format @@ -662,7 +662,7 @@ NVTEShape nvte_tensor_shape(const NVTETensor tensor) { NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - NVTE_ERROR("Invalid tensor"); + NVTE_ERROR("Invalid tensor: received null pointer in nvte_tensor_columnwise_shape"); } const std::vector &shape = t->columnwise_data.shape; return nvte_make_shape(shape.data(), shape.size()); diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 0e286009a5..9c9c43b51a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -463,7 +463,7 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size std::is_same_v) { dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; } else { - NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + NVTE_ERROR("Invalid output type for blockwise transpose (must be FP8: Float8E4M3 or Float8E5M2)."); } CUtensorMap tensor_map_output_trans{}; diff --git a/transformer_engine/jax/cpp_extensions/attention.py 
b/transformer_engine/jax/cpp_extensions/attention.py index e5d75e1501..66e3bb8784 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -165,13 +165,13 @@ def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): kv_max_seqlen = q_max_seqlen num_gqa_groups = attn_heads v_head_dim = q_head_dim - assert nqkv == 3 + assert nqkv == 3, f"Expected nqkv == 3 for qkvpacked layout, but got nqkv={nqkv} from q_aval.shape={q_aval.shape}" elif qkv_layout.is_kvpacked(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == v_head_dim - assert nkv == 2 + assert q_batch_shape == kv_batch_shape, f"Mismatched batch shapes for kvpacked layout: q_batch_shape={q_batch_shape}, kv_batch_shape={kv_batch_shape}" + assert q_head_dim == v_head_dim, f"Mismatched head dims for kvpacked layout: q_head_dim={q_head_dim}, v_head_dim={v_head_dim}" + assert nkv == 2, f"Expected nkv == 2 for kvpacked layout, but got nkv={nkv} from k_aval.shape={k_aval.shape}" elif qkv_layout.is_separate(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape @@ -244,9 +244,9 @@ def check_seed(self, seed, dropout_probability, is_training): ) seed = seed.astype(self.rng_state_dtype) - assert seed.dtype == self.rng_state_dtype + assert seed.dtype == self.rng_state_dtype, f"Expected seed.dtype={self.rng_state_dtype}, but got seed.dtype={seed.dtype}" # Backend takes an int64_t seed, so only the first two u32 elements are taken - assert seed.size >= self.seed_size + assert seed.size >= self.seed_size, f"Expected seed.size >= {self.seed_size}, but got seed.size={seed.size}" return seed @@ -363,7 +363,7 @@ def abstract( # 32-bit unsigned int to get the buffer size we need in the C++ kernel checker = 
_FusedAttnRNGStateChecker() seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype + assert seed_dtype == checker.rng_state_dtype, f"Expected seed_dtype={checker.rng_state_dtype}, but got seed_dtype={seed_dtype}" rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) @@ -408,11 +408,11 @@ def abstract( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - assert softmax_offset_aval.dtype == jnp.float32 + assert softmax_offset_aval.dtype == jnp.float32, f"Expected softmax_offset_aval.dtype=float32, but got {softmax_offset_aval.dtype}" if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: - assert softmax_offset_aval.shape == (1, attn_heads, 1, 1) + assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), f"Expected softmax_offset_aval.shape=(1, {attn_heads}, 1, 1) for {config.softmax_type}, but got {softmax_offset_aval.shape}" else: - assert softmax_offset_aval.shape == (0,) + assert softmax_offset_aval.shape == (0,), f"Expected softmax_offset_aval.shape=(0,) for VANILLA_SOFTMAX, but got {softmax_offset_aval.shape}" return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @@ -533,7 +533,7 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): - assert FusedAttnFwdPrimitive.inner_primitive is not None + assert FusedAttnFwdPrimitive.inner_primitive is not None, "FusedAttnFwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -627,7 +627,7 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnFwdPrimitive.outer_primitive is not None + assert FusedAttnFwdPrimitive.outer_primitive is not None, "FusedAttnFwdPrimitive.outer_primitive has not been registered" q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims 
out_bdims = q_bdim, q_bdim, seed_bdim @@ -778,8 +778,8 @@ def abstract( v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype, f"Mismatched dtypes: q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}, doutput_dtype={doutput_dtype}" + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, f"Mismatched seqlen dtypes: q_seqlen_or_cu_seqlen_aval.dtype={q_seqlen_or_cu_seqlen_aval.dtype}, kv_seqlen_or_cu_seqlen_aval.dtype={kv_seqlen_or_cu_seqlen_aval.dtype}" ( batch_shape, @@ -983,7 +983,7 @@ def impl( _kv_segment_pos, config, ): - assert FusedAttnBwdPrimitive.inner_primitive is not None + assert FusedAttnBwdPrimitive.inner_primitive is not None, "FusedAttnBwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -1023,7 +1023,7 @@ def convert_to_2d(offsets, batch, max_seqlen): batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( q, k, v, config.qkv_layout ) - assert len(batch) == 1 + assert len(batch) == 1, f"Expected len(batch) == 1, but got len(batch)={len(batch)}, batch={batch}" kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 @@ -1082,7 +1082,7 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None + assert FusedAttnBwdPrimitive.outer_primitive is not None, "FusedAttnBwdPrimitive.outer_primitive has not been registered" q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, 
v_bdim, bias_bdim, softmax_offset_bdim @@ -3396,7 +3396,7 @@ def fused_attn_fwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert bias is None, f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: @@ -3414,10 +3414,10 @@ def fused_attn_fwd( softmax_offset, (None, HEAD_AXES, None, None) ) else: - assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX + assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX, f"Expected VANILLA_SOFTMAX when softmax_offset is None and not OFF_BY_ONE_SOFTMAX, but got softmax_type={softmax_type}" softmax_offset = jnp.zeros(0, dtype=jnp.float32) else: - assert softmax_offset.dtype == jnp.float32 + assert softmax_offset.dtype == jnp.float32, f"Expected softmax_offset.dtype=float32, but got softmax_offset.dtype={softmax_offset.dtype}" # Shard by heads dimension if not VANILLA_SOFTMAX if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: softmax_offset = with_sharding_constraint_by_logical_axes( @@ -3556,7 +3556,7 @@ def fused_attn_bwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert bias is None, f"bias must be None when attn_bias_type is NO_BIAS, but got bias with type={type(bias)}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 71f133bfc4..7555f77718 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -166,8 +166,8 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ flatten_axis=flatten_axis, ) - assert not isinstance(lhs_q, ScaledTensor2x) - assert not isinstance(rhs_q, ScaledTensor2x) + assert not isinstance(lhs_q, ScaledTensor2x), f"Expected lhs_q to not be ScaledTensor2x after quantization, but 
got type={type(lhs_q)}" + assert not isinstance(rhs_q, ScaledTensor2x), f"Expected rhs_q to not be ScaledTensor2x after quantization, but got type={type(rhs_q)}" def has_rht_applied(q: AbstractBaseTensor) -> bool: return isinstance(q, ScaledTensor1x) and q.has_rht_applied @@ -529,8 +529,8 @@ def _dims_are_consecutive(dims): f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." ) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - assert alpha.size == 1 and alpha.dtype == jnp.float32 - assert beta.size == 1 and beta.dtype == jnp.float32 + assert alpha.size == 1 and alpha.dtype == jnp.float32, f"Expected alpha to be a single float32 scalar, but got alpha.size={alpha.size}, alpha.dtype={alpha.dtype}" + assert beta.size == 1 and beta.dtype == jnp.float32, f"Expected beta to be a single float32 scalar, but got beta.size={beta.size}, beta.dtype={beta.dtype}" # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() @@ -809,7 +809,7 @@ def batcher( is_outer, ): del transpose_batch_sequence, sequence_dim, is_outer - assert GemmPrimitive.outer_primitive is not None + assert GemmPrimitive.outer_primitive is not None, "GemmPrimitive.outer_primitive has not been registered" lhs_bdims, _, rhs_bdims, *_ = batch_dims # Batched GEMM is not supported @@ -1321,7 +1321,7 @@ def _te_gemm( alpha = jnp.ones((1,), jnp.float32) beta = jnp.zeros((1,), jnp.float32) if scaling_mode.is_nvfp4_scaling: - assert lhs_amax is not None and rhs_amax is not None + assert lhs_amax is not None and rhs_amax is not None, "NVFP4 scaling requires non-None amax for both LHS and RHS operands" lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1402,7 +1402,7 @@ def impl( group_sizes, num_gemms, ): - assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + assert GroupedGemmCopySizesPrimitive.inner_primitive is not None, 
"GroupedGemmCopySizesPrimitive.inner_primitive has not been registered" out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( group_sizes, num_gemms=num_gemms, @@ -1555,7 +1555,7 @@ def impl( is_grouped_dense_wgrad, use_async_d2h_group_sizes, ): - assert GroupedGemmPrimitive.inner_primitive is not None + assert GroupedGemmPrimitive.inner_primitive is not None, "GroupedGemmPrimitive.inner_primitive has not been registered" (out, _) = GroupedGemmPrimitive.inner_primitive.bind( lhs_data, lhs_scale_inv, @@ -1693,7 +1693,7 @@ def _jax_scaled_matmul( lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) if lhs.scaling_mode.is_nvfp4_scaling: - assert lhs.amax is not None and rhs.amax is not None + assert lhs.amax is not None and rhs.amax is not None, "NVFP4 scaling requires non-None amax for both LHS and RHS operands" lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1945,7 +1945,7 @@ def grouped_gemm( lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING elif isinstance(lhs, GroupedScaledTensor1x): - assert isinstance(rhs, GroupedScaledTensor1x) + assert isinstance(rhs, GroupedScaledTensor1x), f"Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but got type={type(rhs)}" out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape @@ -1953,7 +1953,7 @@ def grouped_gemm( rhs_data = rhs.data lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv - assert lhs.scaling_mode == rhs.scaling_mode + assert lhs.scaling_mode == rhs.scaling_mode, f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}, rhs.scaling_mode={rhs.scaling_mode}" scaling_mode = lhs.scaling_mode else: raise TypeError("Unsupported lhs type object!") @@ -1990,8 +1990,8 @@ def grouped_gemm( and not isinstance(rhs, ScaledTensor) and quantizer_set != 
noop_quantizer_set ): - assert isinstance(quantizer_set.x, GroupedQuantizer) - assert type(quantizer_set.x) is type(quantizer_set.kernel) + assert isinstance(quantizer_set.x, GroupedQuantizer), f"Expected quantizer_set.x to be GroupedQuantizer, but got type={type(quantizer_set.x)}" + assert type(quantizer_set.x) is type(quantizer_set.kernel), f"Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" scaling_mode = quantizer_set.x.scaling_mode if ( quantizer_set.x.scaling_mode.is_tensor_scaling() @@ -2057,19 +2057,19 @@ def grouped_gemm( # Calling GroupedGEMM Custom Call K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs + assert K_lhs == K_rhs, f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G if is_grouped_dense_wgrad: N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) else: - assert group_sizes.size == rhs_shape[0] + assert group_sizes.size == rhs_shape[0], f"Expected group_sizes.size == rhs_shape[0], but got group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" - assert group_offset.size == 1 + assert group_offset.size == 1, f"Expected group_offset.size == 1, but got group_offset.size={group_offset.size}" has_bias = bias is not None - assert not has_bias or bias.shape == (group_sizes.size, N) + assert not has_bias or bias.shape == (group_sizes.size, N), f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" bias = jnp.empty((), jnp.float32) if bias is None else bias (out,) = GroupedGemmPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/cpp_extensions/normalization.py 
b/transformer_engine/jax/cpp_extensions/normalization.py index 70fdf4c474..47e0f6166e 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -132,9 +132,9 @@ def abstract( ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval is None or scale_aval.dtype == jnp.float32 - assert amax_aval is None or amax_aval.dtype == jnp.float32 + assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16], f"Unsupported x_dtype={x_dtype}, expected one of [float32, float16, bfloat16]" + assert scale_aval is None or scale_aval.dtype == jnp.float32, f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert amax_aval is None or amax_aval.dtype == jnp.float32, f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -159,7 +159,7 @@ def abstract( mu_rsigama_dtype = jnp.float32 if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.size == beta_aval.size + assert gamma_aval.size == beta_aval.size, f"Expected gamma_aval.size == beta_aval.size, but got gamma_aval.size={gamma_aval.size}, beta_aval.size={beta_aval.size}" assert gamma_aval.dtype == beta_aval.dtype, ( f"gamma and beta should have the same dtype, but got {gamma_aval.dtype} and " f"{beta_aval.dtype}" @@ -265,18 +265,18 @@ def lowering( del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval is None or scale_aval.dtype == jnp.float32 - assert amax_aval is None or amax_aval.dtype == jnp.float32 + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16], f"Unsupported x_aval.dtype={x_aval.dtype}, expected one of [float32, float16, bfloat16]" + assert scale_aval is None or 
scale_aval.dtype == jnp.float32, f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert amax_aval is None or amax_aval.dtype == jnp.float32, f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.dtype == beta_aval.dtype + assert gamma_aval.dtype == beta_aval.dtype, f"Expected gamma and beta to have the same dtype, but got gamma_aval.dtype={gamma_aval.dtype}, beta_aval.dtype={beta_aval.dtype}" b_type = ir.RankedTensorType(beta.type) b_shape = b_type.shape - assert g_type == b_type - assert g_shape == b_shape + assert g_type == b_type, f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}, beta_type={b_type}" + assert g_shape == b_shape, f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}, beta_shape={b_shape}" sm_margin = get_forward_sm_margin() return ffi.ffi_lowering( @@ -321,7 +321,7 @@ def impl( to describe implementation """ del is_outer - assert NormFwdPrimitive.inner_primitive is not None + assert NormFwdPrimitive.inner_primitive is not None, "NormFwdPrimitive.inner_primitive has not been registered" ( out, colwise_out, @@ -391,7 +391,7 @@ def batcher( to describe batch rules for vmap """ check_valid_batch_dims(batch_dims) - assert NormFwdPrimitive.outer_primitive is not None + assert NormFwdPrimitive.outer_primitive is not None, "NormFwdPrimitive.outer_primitive has not been registered" x, scale, amax, gamma, beta = batched_args x_bdim, scale_bdim, _, _, _ = batch_dims @@ -706,13 +706,13 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) - assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype - assert dz_aval.shape == x_aval.shape + assert 
dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype, f"Expected dz_aval.dtype={w_dtype} (matching gamma dtype), but got dz_aval.dtype={dtypes.canonicalize_dtype(dz_aval.dtype)}" + assert dz_aval.shape == x_aval.shape, f"Expected dz_aval.shape == x_aval.shape, but got dz_aval.shape={dz_aval.shape}, x_aval.shape={x_aval.shape}" if norm_type == NVTE_Norm_Type.LayerNorm: mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) - assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] - assert mu_dtype == rsigma_dtype == jnp.float32 + assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], f"Expected mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], but got mu_aval.shape={mu_aval.shape}, rsigma_aval.shape={rsigma_aval.shape}, x_aval.shape[:-1]={x_aval.shape[:-1]}" + assert mu_dtype == rsigma_dtype == jnp.float32, f"Expected mu_dtype == rsigma_dtype == float32, but got mu_dtype={mu_dtype}, rsigma_dtype={rsigma_dtype}" dx_aval = dz_aval dgamma_aval = dbeta_aval = gamma_aval @@ -756,8 +756,8 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): g_shape = g_type.shape b_type = ir.RankedTensorType(gamma.type) b_shape = b_type.shape - assert g_type == b_type - assert g_shape == b_shape + assert g_type == b_type, f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}, beta_type={b_type}" + assert g_shape == b_shape, f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}, beta_shape={b_shape}" sm_margin = get_backward_sm_margin() return ffi.ffi_lowering(NormBwdPrimitive.name)( @@ -774,7 +774,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): @staticmethod def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): - assert NormBwdPrimitive.inner_primitive is not None + assert NormBwdPrimitive.inner_primitive is not None, "NormBwdPrimitive.inner_primitive has not been registered" dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( 
dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma ) @@ -783,7 +783,7 @@ def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): @staticmethod def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): check_valid_batch_dims(batch_dims) - assert NormBwdPrimitive.outer_primitive is not None + assert NormBwdPrimitive.outer_primitive is not None, "NormBwdPrimitive.outer_primitive has not been registered" dz, x, mu, rsigma, gamma = batched_args _, x_bdim, _, _, gamma_bdim = batch_dims diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad5a60e4c2..3e1828c681 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -182,7 +182,9 @@ def __call__( is_gqa = h_q != h_kv if is_gqa: - assert (h_q % h_kv == 0) and (h_q >= h_kv) + assert (h_q % h_kv == 0) and (h_q >= h_kv), ( + f"num_query_heads ({h_q}) must be divisible by and >= num_kv_heads ({h_kv})" + ) group_size = h_q // h_kv grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) @@ -428,7 +430,7 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) - assert x.dtype == query.dtype + assert x.dtype == query.dtype, f"output dtype {x.dtype} does not match query dtype {query.dtype}" return x @@ -713,9 +715,9 @@ def __call__( del self.attn_bias_type, self.attn_mask_type, self.qkv_layout if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None + assert bias is None, f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" else: - assert bias is not None + assert bias is not None, f"bias must not be None when attn_bias_type is {attn_bias_type}" bias = bias.astype(input_dtype) self._assert_dtypes(query, key, value, qkv_layout) @@ -823,10 +825,14 @@ def __call__( key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: - 
assert qkv_layout.is_separate() + assert qkv_layout.is_separate(), ( + f"Expected separate qkv_layout, but got {qkv_layout}" + ) assert sequence_descriptor is None or isinstance( sequence_descriptor, (jnp.ndarray, np.ndarray) + ), ( + f"sequence_descriptor must be None or ndarray, but got {type(sequence_descriptor)}" ) x = _UnfusedDotProductAttention( @@ -994,7 +1000,7 @@ def _canonicalize_lora_scope(scope): SCOPE_EX_QKV_PROJ, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP, - ] + ], f"Unsupported LoRA scope: {scope}" lora_scope = LoRAScope() @@ -1307,8 +1313,8 @@ def query_init(*args): return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) def qkv_init(key, shape, dtype): - assert len(shape) == 3 - assert shape[-2] == 3 + assert len(shape) == 3, f"qkv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert shape[-2] == 3, f"qkv_init expects shape[-2] == 3, but got shape={shape}" q_key, k_key, v_key = jax_random.split(key, num=3) @@ -1323,8 +1329,8 @@ def qkv_init(key, shape, dtype): return jnp.stack([q_kernel, k_kernel, v_kernel], axis=-2, dtype=dtype) def kv_init(key, shape, dtype): - assert len(shape) == 3 - assert shape[-2] == 2 + assert len(shape) == 3, f"kv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert shape[-2] == 2, f"kv_init expects shape[-2] == 2, but got shape={shape}" k_key, v_key = jax_random.split(key) @@ -1415,7 +1421,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(inputs_q) if is_self_attn: - assert ln_out is not None + assert ln_out is not None, "ln_out must not be None for self-attention" inputs_kv = ln_out kv_proj = DenseGeneral( @@ -1475,7 +1481,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(inputs_q) if is_self_attn: - assert ln_out is not None + assert ln_out is not None, "ln_out must not be None for self-attention" inputs_kv = ln_out query = query.astype(input_dtype) @@ -1494,7 +1500,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): elif qkv_layout 
== QKVLayout.BSHD_BS2HD: key, value = jnp.split(kv_proj, [1], axis=-2) else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, ( + f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" + ) # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact) query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) @@ -1520,7 +1528,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) if decode: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, ( + f"decode mode requires QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" + ) is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) @@ -1588,7 +1598,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) dpa_args = [query, kv_proj, None] else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, ( + f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" + ) query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) @@ -2101,7 +2113,7 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): l = inputs.shape[sequence_dim] attn_bias = rel_emb(l, l, False) - assert inputs.ndim == 3 + assert inputs.ndim == 3, f"inputs must be 3D (batch, sequence, hidden), but got {inputs.ndim}D" # Make name be the exactly same as T5X, since names would affect # RNGKey during init and apply. Myabe no need in the feature. 
@@ -2151,10 +2163,15 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) def hidden_dropout(x, deterministic): - assert isinstance(self.hidden_dropout_dims, Sequence) + assert isinstance(self.hidden_dropout_dims, Sequence), ( + f"hidden_dropout_dims must be a Sequence, but got {type(self.hidden_dropout_dims)}" + ) x_shape_len = len(x.shape) for dims in self.hidden_dropout_dims: - assert -x_shape_len <= dims < x_shape_len + assert -x_shape_len <= dims < x_shape_len, ( + f"hidden_dropout_dims value {dims} is out of range " + f"[{-x_shape_len}, {x_shape_len}) for input with {x_shape_len} dimensions" + ) return nn.Dropout( rate=self.hidden_dropout, @@ -2179,7 +2196,9 @@ def hidden_dropout(x, deterministic): )(x, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ln_out is not None, ( + "ln_out must not be None when apply_residual_connection_post_layernorm is True" + ) residual = ln_out x = x + residual @@ -2239,7 +2258,9 @@ def hidden_dropout(x, deterministic): y = hidden_dropout(y, deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ln_out is not None, ( + "ln_out must not be None when apply_residual_connection_post_layernorm is True" + ) residual = ln_out mlp_input = y + residual @@ -2284,7 +2305,9 @@ def hidden_dropout(x, deterministic): )(mlp_input, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None + assert ln_out is not None, ( + "ln_out must not be None when apply_residual_connection_post_layernorm is True" + ) residual = ln_out z = with_sharding_constraint_by_logical_axes( diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 3f3f3802db..83a8544256 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -31,7 +31,11 @@ 
def canonicalize_norm_type(x): Canonicalized normalization type string """ canonicalized = x.lower().strip().replace("-", "").replace("_", "") - assert canonicalized in ["layernorm", "rmsnorm"] + if canonicalized not in ["layernorm", "rmsnorm"]: + raise ValueError( + f"Unsupported normalization type '{x}' (canonicalized: '{canonicalized}'). " + f"Valid options are: 'layernorm', 'rmsnorm'." + ) return canonicalized diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 101e5b2525..d78d2ab1f0 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -270,14 +270,20 @@ def fused_attn_fwd( attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert attn_bias is not None, ( + f"attn_bias tensor cannot be None when attn_bias_type={attn_bias_type!r}." + ) + assert attn_bias.dtype == q.dtype, ( + f"attn_bias tensor must have the same dtype as q and kv: " + f"attn_bias.dtype={attn_bias.dtype} but q.dtype={q.dtype}." + ) + + assert fused_attention_backend != FusedAttnBackend["No_Backend"], ( + f"Fused attention does not support this input combination:" + f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," + f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," + f" q.dtype={q.dtype}, backend={fused_attention_backend}." 
+ ) # BF16/FP16 fused attention API from fmha_v1 apex if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: @@ -293,12 +299,14 @@ def fused_attn_fwd( max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention." - assert ( - o_quantizer is not None - ), "o_quantizer is required as an input for FP8 fused attention." + assert s_quantizer is not None, ( + f"s_quantizer is required for FP8 fused attention forward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + assert o_quantizer is not None, ( + f"o_quantizer is required for FP8 fused attention forward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -487,28 +495,38 @@ def fused_attn_bwd( d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." + assert fused_attention_backend != FusedAttnBackend["No_Backend"], ( + f"Fused attention backward does not support this input combination:" + f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," + f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," + f" q.dtype={q.dtype}, backend={fused_attention_backend}." + ) if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." + assert len(aux_ctx_tensors) >= 1, ( + f"aux_ctx_tensors must contain rng_state as its last element," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" for backend={fused_attention_backend}." 
+ ) if fused_attention_backend == FusedAttnBackend["FP8"]: - assert ( - s_quantizer is not None - ), "s_quantizer is required as an input for FP8 fused attention backward." - assert ( - dp_quantizer is not None - ), "dp_quantizer is required as an input for FP8 fused attention backward." - assert ( - dqkv_dtype is not None - ), "dqkv_dtype is required as an input for FP8 fused attention backward." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + assert s_quantizer is not None, ( + f"s_quantizer is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + assert dp_quantizer is not None, ( + f"dp_quantizer is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + assert dqkv_dtype is not None, ( + f"dqkv_dtype is required for FP8 fused attention backward" + f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." + ) + assert len(aux_ctx_tensors) == 3, ( + f"aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" (backend={fused_attention_backend})." + ) output_tensors = tex.fused_attn_bwd( max_seqlen_q, diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 05219b7b18..164124ae29 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -124,7 +124,10 @@ def tensor_group_process_after_reload(tensor_group: TensorGroup): """ Call for a tensor group, just after reload logic. 
""" - assert tensor_group.aux is not None + assert tensor_group.aux is not None, ( + "TensorGroup.aux must be set before post-reload processing, " + f"but got aux=None for tensor_group with {len(tensor_group.tensor_list)} tensors" + ) tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group) tensor_group = TensorGroupProcessor._switch_to_views(tensor_group) return tensor_group @@ -271,7 +274,11 @@ def start_offload(self): ) for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list): - assert tensor.is_contiguous() + assert tensor.is_contiguous(), ( + f"Tensor at index {tensor_id} must be contiguous for CPU offloading, " + f"but got non-contiguous tensor with shape={tensor.shape}, " + f"stride={tensor.stride()}, dtype={tensor.dtype}" + ) # Wait for the moment the tensor is ready to be offloaded. self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type] @@ -420,7 +427,10 @@ def pop_tensor( return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] # 4. the layer was offloaded - assert self.state == "reload_started" + assert self.state == "reload_started", ( + f"Expected state='reload_started' when popping an offloaded tensor, " + f"but got state='{self.state}' for tensor={tensor_or_tensor_id}" + ) # wait for the tensor to be reloaded torch.cuda.current_stream().wait_event( self.bwd_gpu_tensor_group.events[tensor_or_tensor_id] @@ -882,8 +892,16 @@ def synchronization_function(self, tensor): """ This function is used to catch the backward pass of the model. """ - assert tensor.requires_grad is True - assert self.current_layer is not None + assert tensor.requires_grad is True, ( + f"Tensor passed to synchronization_function must require grad to " + f"register backward hooks, but got requires_grad=False for tensor " + f"with shape={tensor.shape}, dtype={tensor.dtype}" + ) + assert self.current_layer is not None, ( + "synchronization_function called but no layer has been set via __enter__. 
" + f"inside_context={self.inside_context}, " + f"offload_synchronizer num_layers={self.offload_synchronizer.num_layers}" + ) cur_layer = self.current_layer assert ( self.inside_context is False diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 645dbb48d2..e6248f2dcc 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -272,7 +272,9 @@ at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType } else if (size == 1) { return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); } - NVTE_CHECK(false, "Should never reach here! func: allocateSpace"); + NVTE_ERROR("Unsupported tensor allocation: ndim=", size, + ", init_to_zeros=", init_to_zeros, + ". Only 1D and 2D tensors are supported."); } at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..7a9ac63e4e 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1475,7 +1475,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // Compute amax. if (this->with_rht) { if (input.dtype() != DType::kBFloat16) { - NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + NVTE_ERROR("RHT is only supported for bfloat16 input, got dtype enum value ", + static_cast(input.dtype())); } if (this->with_post_rht_amax) { // We need: @@ -1487,7 +1488,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou }); } else { // raise error since it's not supported yet - NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + NVTE_ERROR("Pre-RHT amax is not supported yet. 
" + "Use with_post_rht_amax=true instead."); } } else { // Without RHT if (compute_amax) { diff --git a/transformer_engine/pytorch/custom_recipes/gemm.py b/transformer_engine/pytorch/custom_recipes/gemm.py index 8f853ff093..955640ec15 100644 --- a/transformer_engine/pytorch/custom_recipes/gemm.py +++ b/transformer_engine/pytorch/custom_recipes/gemm.py @@ -68,11 +68,11 @@ def custom_gemm( if gemm_type == GEMMType.FPROP: qx, sx = A.data, A.scale qw, sw = B.data, B.scale - assert qx is not None - assert sx is not None - assert qw is not None - assert sw is not None - assert A.original_shape is not None + assert qx is not None, "FPROP GEMM: quantized activation data (A.data) is None" + assert sx is not None, "FPROP GEMM: activation scale (A.scale) is None" + assert qw is not None, "FPROP GEMM: quantized weight data (B.data) is None" + assert sw is not None, "FPROP GEMM: weight scale (B.scale) is None" + assert A.original_shape is not None, "FPROP GEMM: A.original_shape is None, cannot determine output shape" # Call quantizer's qgemm method result = quantizer.qgemm( @@ -95,10 +95,10 @@ def custom_gemm( elif gemm_type == GEMMType.DGRAD: qdy, sdy = A.data, A.scale qw_t, sw_t = B.data_t, B.scale_t - assert qdy is not None - assert sdy is not None - assert qw_t is not None - assert sw_t is not None + assert qdy is not None, "DGRAD GEMM: quantized gradient data (A.data) is None" + assert sdy is not None, "DGRAD GEMM: gradient scale (A.scale) is None" + assert qw_t is not None, "DGRAD GEMM: transposed quantized weight data (B.data_t) is None" + assert sw_t is not None, "DGRAD GEMM: transposed weight scale (B.scale_t) is None" result = quantizer.qgemm( qdy, @@ -115,10 +115,10 @@ def custom_gemm( elif gemm_type == GEMMType.WGRAD: qdy_t, sdy_t = A.data_t, A.scale_t qx_t, sx_t = B.data_t, B.scale_t - assert qdy_t is not None - assert sdy_t is not None - assert qx_t is not None - assert sx_t is not None + assert qdy_t is not None, "WGRAD GEMM: transposed quantized gradient data 
(A.data_t) is None" + assert sdy_t is not None, "WGRAD GEMM: transposed gradient scale (A.scale_t) is None" + assert qx_t is not None, "WGRAD GEMM: transposed quantized activation data (B.data_t) is None" + assert sx_t is not None, "WGRAD GEMM: transposed activation scale (B.scale_t) is None" result = quantizer.qgemm( qdy_t, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index d00d0c8b94..7e07355bcd 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -325,7 +325,9 @@ def size(self, *args, **kwargs): # pylint: disable=unused-argument the second dimension by half. This method returns the logical shape that users expect, not the internal packed storage shape. """ - assert self.original_shape is not None + assert self.original_shape is not None, ( + "NVFP4TensorRef.size() called but original_shape has not been set" + ) return torch.Size(self.original_shape) @@ -446,7 +448,9 @@ def _quantize_blockwise_reference( eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.ndim == 2 + assert x.ndim == 2, ( + f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape {x.shape}" + ) using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 m, n = x.shape # Compute vec_max based on the original x (before reshape) @@ -766,7 +770,9 @@ def is_data_t_transposed_in_memory(self) -> bool: TODO(etsykunov): Confirm docstring is correct. 
""" - raise NotImplementedError("Not implemented yet") + raise NotImplementedError( + "NVFP4QuantizerRef.is_data_t_transposed_in_memory is not implemented for FP4 quantization" + ) def qgemm( self, @@ -784,7 +790,7 @@ def qgemm( qresult_w: QuantizedTensorStorage | None = None, ) -> torch.Tensor: """Python implementation of microblock FP4 GEMM.""" - assert bias is None, "Bias is implemented for FP4 GEMM." + assert bias is None, "Bias is not supported in NVFP4QuantizerRef.qgemm" high_precision_x = cast_from_fp4x2(qx, out_dtype) high_precision_w = cast_from_fp4x2(qw, out_dtype) @@ -814,11 +820,19 @@ def qgemm( else: - assert qresult_x is not None - assert qresult_w is not None + assert qresult_x is not None, ( + "qresult_x is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + ) + assert qresult_w is not None, ( + "qresult_w is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + ) - assert qresult_x.global_amax_row is not None - assert qresult_w.global_amax_col is not None + assert qresult_x.global_amax_row is not None, ( + "qresult_x.global_amax_row must be set for non-pow_2_scales NVFP4 GEMM" + ) + assert qresult_w.global_amax_col is not None, ( + "qresult_w.global_amax_col must be set for non-pow_2_scales NVFP4 GEMM" + ) sx = sx.to(torch.float32) sw = sw.to(torch.float32) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..2e7c53b9d3 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -152,7 +152,11 @@ def set_tensor_model_parallel_attributes( ) -> None: """set attributes needed for TP""" for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) + if hasattr(tensor, attribute): + raise RuntimeError( + f"Tensor already has attribute '{attribute}' set. Cannot set " + f"tensor model parallel attributes on a tensor that already has them." + ) # Set the attributes. 
setattr(tensor, "tensor_model_parallel", is_parallel) setattr(tensor, "partition_dim", dim) @@ -170,7 +174,11 @@ def get_distributed_world_size(group: Optional[dist_group_type] = None) -> int: @lru_cache def get_distributed_rank(group: Optional[dist_group_type] = None) -> int: """Return my rank for the distributed group.""" - assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call torch.distributed.init_process_group() " + "before calling get_distributed_rank()." + ) return torch.distributed.get_rank(group=group) @@ -743,7 +751,12 @@ def checkpoint( # If saved activations need to be distributed but there is no process group, # default to the world group. if distribute_saved_activations: - assert torch.distributed.is_initialized(), "torch.distributed is not initialized." + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed is not initialized. Call " + "torch.distributed.init_process_group() before using " + "distribute_saved_activations=True." + ) tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group return _CheckpointFunction.apply( @@ -917,9 +930,12 @@ def reduce_scatter_along_first_dim( return inp, None dim_size = list(inp.size()) - assert ( - dim_size[0] % world_size == 0 - ), "First dimension of the tensor should be divisible by tensor parallel size" + if dim_size[0] % world_size != 0: + raise ValueError( + f"First dimension of the tensor should be divisible by tensor parallel size, " + f"but got dim_size[0]={dim_size[0]} and world_size={world_size} " + f"(remainder={dim_size[0] % world_size})." + ) dim_size[0] = dim_size[0] // world_size @@ -984,7 +1000,11 @@ def _all_gather_fp8( # Note: We cannot directly all-gather the transposed FP8 tensor, # so temporarily modify quantizer to avoid creating FP8 transpose. 
if not isinstance(inp, Float8TensorStorage): - assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + raise TypeError( + f"Expected quantizer to be Float8Quantizer or Float8CurrentScalingQuantizer " + f"when input is not Float8TensorStorage, but got {type(quantizer).__name__}." + ) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer # and then set it back to the original value after quantizing @@ -1231,10 +1251,18 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): """ shape = tensor.shape - assert len(shape) >= 2, "Wrong number of dimensions for fixing interleave." + if len(shape) < 2: + raise ValueError( + f"Wrong number of dimensions for fixing interleave: got {len(shape)}, " + f"expected at least 2 (shape={shape})." + ) first_dim = shape[0] flattened_trailing = math.prod(shape[1:]) - assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." + if first_dim % world_size != 0: + raise ValueError( + f"Wrong dimensions for fixing interleave: first_dim={first_dim} is not divisible " + f"by world_size={world_size} (remainder={first_dim % world_size})." + ) tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing) tensor = tex.swap_first_dims(tensor, out=None) return tensor.reshape(first_dim // world_size, flattened_trailing * world_size) @@ -1324,7 +1352,11 @@ def _all_gather_nvfp4( f"found {inp.__class__.__name__})" ) - assert in_shape is not None or in_shape_t is not None, "No data found." + if in_shape is None and in_shape_t is None: + raise ValueError( + "No data found: both in_shape and in_shape_t are None. " + "Input tensor must have rowwise or columnwise data." 
+ ) world_size = get_distributed_world_size(process_group) @@ -1374,7 +1406,11 @@ def _all_gather_nvfp4( if quantizer.rowwise_usage: # Remove padding from NVFP4 scale-inverses - assert in_shape is not None, "Shape not found." + if in_shape is None: + raise RuntimeError( + "Shape not found: in_shape is None but rowwise_usage is True. " + "Input tensor must have rowwise data for NVFP4 rowwise gathering." + ) in_scale_inv = inp._rowwise_scale_inv out_scale_inv = out._rowwise_scale_inv flattened_in_shape0 = math.prod(in_shape[:-1]) @@ -1672,7 +1708,11 @@ def gather_along_first_dim( # MXFP8 case if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): - assert isinstance(quantizer, MXFP8Quantizer) + if not isinstance(quantizer, MXFP8Quantizer): + raise TypeError( + f"Expected MXFP8Quantizer for MXFP8 all-gather, " + f"but got {type(quantizer).__name__}." + ) return _all_gather_mxfp8( inp, process_group, @@ -1683,7 +1723,11 @@ def gather_along_first_dim( # NVFP4 case if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): - assert isinstance(quantizer, NVFP4Quantizer) + if not isinstance(quantizer, NVFP4Quantizer): + raise TypeError( + f"Expected NVFP4Quantizer for NVFP4 all-gather, " + f"but got {type(quantizer).__name__}." + ) return _all_gather_nvfp4( inp, process_group, @@ -1826,8 +1870,15 @@ def symmetric_all_reduce( - The second element is the async work handle if async_op=True, otherwise None. """ - assert async_op is False, "Async symmetric ops no supported yet" - assert HAS_TORCH_SYMMETRIC, "Could not import symetric memory from torch" + if async_op is not False: + raise RuntimeError( + f"Async symmetric ops are not supported yet, but async_op={async_op!r} was passed." + ) + if not HAS_TORCH_SYMMETRIC: + raise RuntimeError( + "Could not import symmetric memory from torch. " + "Please ensure torch.distributed._symmetric_memory is available." 
+        )
 
     if get_distributed_world_size(tp_group) == 1:
         return inp, None
@@ -1960,10 +2011,19 @@ def _fsdp_gather_tensors(
     *tensors: torch.Tensor,
 ):
     if fsdp_group is not None:
-        assert len(shapes) == len(tensors), "Number of tensors and tensor shapes must be equal."
+        if len(shapes) != len(tensors):
+            raise ValueError(
+                f"Number of tensors and tensor shapes must be equal, "
+                f"but got {len(shapes)} shapes and {len(tensors)} tensors."
+            )
         for s, t in zip(shapes, tensors):
             if isinstance(t, torch.Tensor):
-                assert s is not None, "Internal TE error."
+                if s is None:
+                    raise RuntimeError(
+                        "Internal TE error: shape is None for a non-None tensor in "
+                        "_fsdp_gather_tensors. "
+                        f"Tensor type: {type(t).__name__}, tensor shape: {t.shape}."
+                    )
                 targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t]
                 for target in targets:
                     safely_set_viewless_tensor_data(
@@ -2011,17 +2071,25 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
     fsdp_root : torch.nn.Module
                 FSDP-wrapped root module that may contain FSDP-wrapped TE modules.
     """
-    assert isinstance(fsdp_root, FSDP), "Root module must be FSDP-wrapped."
+    if not isinstance(fsdp_root, FSDP):
+        raise TypeError(
+            f"Root module must be FSDP-wrapped, but got {type(fsdp_root).__name__}."
+        )
 
     # If the root module is a TE module, inject FSDP information into it
     if _is_te_module(fsdp_root.module):
         if hasattr(fsdp_root, "primary_weights_in_fp8"):
-            assert not fsdp_root.primary_weights_in_fp8, (
-                "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
-                "Please initialize your model without the te.quantized_model_init(...) context."
-            )
+            if fsdp_root.primary_weights_in_fp8:
+                raise RuntimeError(
+                    "TE modules with primary weights in FP8 cannot be FSDP-wrapped. "
+                    "Please initialize your model without the te.quantized_model_init(...) context."
+                )
         root_state = _get_module_fsdp_state(fsdp_root)
-        assert root_state is not None, "Root module does not have a valid _FSDPState."
+ if root_state is None: + raise RuntimeError( + f"Root module ({type(fsdp_root.module).__name__}) does not have a valid " + f"_FSDPState. Ensure the module is properly wrapped with FSDP." + ) fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) # Iterate through all FSDP-wrapped submodules and inject FSDP information into TE modules @@ -2029,10 +2097,12 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: for state, fsdp_module in zip(fsdp_states, fsdp_modules): if _is_te_module(fsdp_module.module): if hasattr(fsdp_module.module, "primary_weights_in_fp8"): - assert not fsdp_module.module.primary_weights_in_fp8, ( - "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.quantized_model_init(...) context." - ) + if fsdp_module.module.primary_weights_in_fp8: + raise RuntimeError( + f"TE module '{type(fsdp_module.module).__name__}' with primary weights " + "in FP8 cannot be FSDP-wrapped. Please initialize your model without " + "the te.quantized_model_init(...) context." 
+ ) fsdp_module.module.fast_setattr("fsdp_group", state.process_group) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f4b1fb23ae..fb7b695854 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -146,8 +146,14 @@ def _make_graphed_callables( _order_without_wgrad = None delay_wgrad_compute = False if _order is None: - assert len(sample_args) == len(callables) - assert len(sample_kwargs) == len(callables) + assert len(sample_args) == len(callables), ( + f"Expected sample_args to have the same length as callables, " + f"but got {len(sample_args)} sample_args for {len(callables)} callables" + ) + assert len(sample_kwargs) == len(callables), ( + f"Expected sample_kwargs to have the same length as callables, " + f"but got {len(sample_kwargs)} sample_kwargs for {len(callables)} callables" + ) else: # Custom logic for interleaved pipeline parallelism # Note: This is tightly coupled with the Megatron-core @@ -171,7 +177,12 @@ def _make_graphed_callables( _order_without_wgrad.append(c_id) num_model_chunks = max(_order_without_wgrad) num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2 - assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad) + assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad), ( + f"Pipeline-parallel order dimension mismatch: " + f"num_model_chunks ({num_model_chunks}) * num_microbatches ({num_microbatches}) * 2 " + f"= {num_model_chunks * num_microbatches * 2}, " + f"but len(_order_without_wgrad) = {len(_order_without_wgrad)}" + ) # When delay_wgrad_compute is enabled, each layer is treated as a model chunk, which # allows for fine-grained graph capture order. 
@@ -220,7 +231,11 @@ def _make_graphed_callables( num_layers = _num_layers_per_chunk[m_chunk] _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers) - assert len(sample_kwargs) == len(sample_args) + assert len(sample_kwargs) == len(sample_args), ( + f"Pipeline-parallel schedule requires sample_kwargs and sample_args to have " + f"the same length, but got {len(sample_kwargs)} sample_kwargs " + f"for {len(sample_args)} sample_args" + ) # Check reuse graph conditions and reorganize sample_args and sample_kwargs. # Note: When capturing a graph, we hold onto the args and kwargs so we have static buffers @@ -352,7 +367,11 @@ def _make_graphed_callables( ) else () ) - assert len(per_callable_module_params) == len(flatten_sample_args) + assert len(per_callable_module_params) == len(flatten_sample_args), ( + f"Pipeline-parallel dimension mismatch: " + f"per_callable_module_params has {len(per_callable_module_params)} entries, " + f"but flatten_sample_args has {len(flatten_sample_args)} entries" + ) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] for i in range(len(flatten_sample_args)) @@ -800,7 +819,9 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Replay forward graph fwd_graph.replay() - assert isinstance(static_outputs, tuple) + assert isinstance(static_outputs, tuple), ( + f"Expected static_outputs to be a tuple, but got {type(static_outputs)}" + ) return tuple(o.detach() if o is not None else o for o in static_outputs) @staticmethod @@ -809,7 +830,11 @@ def backward(ctx, *grads): # pylint: disable=missing-function-docstring # Replay backward graph - assert len(grads) == len(static_grad_outputs) + assert len(grads) == len(static_grad_outputs), ( + f"Backward graph grad dimension mismatch: " + f"received {len(grads)} grads, " + f"but expected {len(static_grad_outputs)} static_grad_outputs" + ) for g, grad in zip(static_grad_outputs, grads): if g is not None: # don't copy if autograd gods have been kind and 
the @@ -823,7 +848,9 @@ def backward(ctx, *grads): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Input args that didn't require grad expect a None gradient. - assert isinstance(static_grad_inputs, tuple) + assert isinstance(static_grad_inputs, tuple), ( + f"Expected static_grad_inputs to be a tuple, but got {type(static_grad_inputs)}" + ) return (None,) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 09b12afa21..8ed7d14aec 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -80,7 +80,8 @@ class UserBufferQuantizationMode(Enum): def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor: """Returns a dummy tensor of given shape.""" - assert len(shape) == 2 + if len(shape) != 2: + raise ValueError(f"Expected 2D shape, got {len(shape)}D: {shape}") global _dummy_wgrads if (shape[0], shape[1], dtype) not in _dummy_wgrads: _dummy_wgrads[(shape[0], shape[1], dtype)] = torch.empty( @@ -156,10 +157,11 @@ def initialize_ub( which also requires ``MPI_HOME=/path/to/mpi/root`` to be set at compile time. """ if not tex.device_supports_multicast(): - assert bool(int(os.getenv("UB_SKIPMC", "0"))), ( - "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap with " - + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." - ) + if not bool(int(os.getenv("UB_SKIPMC", "0"))): + raise RuntimeError( + "CUDA device, driver and/or toolkit version does not support comm+GEMM overlap " + "with CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." 
+ ) if not quantization_modes: warnings.warn( @@ -171,34 +173,51 @@ def initialize_ub( UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE ] else: - assert isinstance(quantization_modes, list), "quantization_modes must be a list" - assert all( - isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes - ), "quantization_modes must be a list of UserBufferQuantizationMode" + if not isinstance(quantization_modes, list): + raise TypeError( + f"quantization_modes must be a list, got {type(quantization_modes).__name__}" + ) + invalid_modes = [ + mode for mode in quantization_modes + if not isinstance(mode, UserBufferQuantizationMode) + ] + if invalid_modes: + raise TypeError( + f"quantization_modes must be a list of UserBufferQuantizationMode, " + f"got invalid entries: {invalid_modes}" + ) if isinstance(ub_cfgs, dict) or ub_cfgs is None: ub_cfgs = [ub_cfgs] * len(quantization_modes) else: - assert len(ub_cfgs) == len( - quantization_modes - ), "Number of ub_cfgs settings must match number of quantization configurations" + if len(ub_cfgs) != len(quantization_modes): + raise ValueError( + f"Number of ub_cfgs settings ({len(ub_cfgs)}) must match number of " + f"quantization configurations ({len(quantization_modes)})" + ) global _ub_communicators - assert _ub_communicators is None, "UB communicators are already initialized." + if _ub_communicators is not None: + raise RuntimeError("UB communicators are already initialized.") _ub_communicators = {} if tex.ubuf_built_with_mpi(): # We're bootstrapping with direct calls to MPI in Userbuffers code so we need to force # an MPI_Init() here by creating a new MPI process group... 
- assert torch.distributed.is_mpi_available() + if not torch.distributed.is_mpi_available(): + raise RuntimeError( + "MPI backend is not available in torch.distributed but is required " + "when Userbuffers is built with MPI support" + ) _ = torch.distributed.new_group(backend="mpi") helper = tex.CommOverlapHelper() else: # Bootstrapping with torch.distributed API, so check backend and construct # intra/inter-node process groups... - assert ( - torch.distributed.is_initialized() - ), "torch.distributed must be initialized before Userbuffers" + if not torch.distributed.is_initialized(): + raise RuntimeError( + "torch.distributed must be initialized before using Userbuffers" + ) if bootstrap_backend is None: bootstrap_backend = "nccl" if torch.distributed.is_mpi_available(): @@ -206,15 +225,16 @@ def initialize_ub( elif torch.distributed.is_gloo_available(): bootstrap_backend = "gloo" else: - assert bootstrap_backend in [ - "gloo", - "mpi", - "nccl", - ], "Invalid torch.distributed backend for bootstrapping Userbuffers!" - assert torch.distributed.is_backend_available(bootstrap_backend), ( - f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " - f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." - ) + if bootstrap_backend not in ["gloo", "mpi", "nccl"]: + raise ValueError( + f"Invalid torch.distributed backend '{bootstrap_backend}' for bootstrapping " + f"Userbuffers. Must be one of: 'gloo', 'mpi', 'nccl'" + ) + if not torch.distributed.is_backend_available(bootstrap_backend): + raise RuntimeError( + f"PyTorch must be compiled with '{bootstrap_backend}' support in order to " + f"bootstrap Userbuffers with '{bootstrap_backend}' collectives." + ) world_group = torch.distributed.new_group(backend=bootstrap_backend) world_rank = torch.distributed.get_rank(world_group) @@ -333,9 +353,11 @@ def add_ub( warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." 
) - assert ( - quantization_mode == UserBufferQuantizationMode.FP8 - ), "Atomic GEMM overlap supported only for FP8 GEMM." + if quantization_mode != UserBufferQuantizationMode.FP8: + raise ValueError( + f"Atomic GEMM overlap supported only for FP8 GEMM, " + f"got quantization_mode={quantization_mode}" + ) if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." @@ -360,20 +382,24 @@ def add_ub( "for functionality." ) if name in layers_atomic_ring_exchange: - assert atomic_gemm and method == "ring_exchange", assert_message + if not (atomic_gemm and method == "ring_exchange"): + raise ValueError(assert_message) else: if atomic_gemm and method == "ring_exchange": - assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + if rs_ag_pairs[name] not in layers_atomic_ring_exchange: + raise ValueError(assert_message) if name in external_gemm_to_overlap: - assert method == "external", ( - f"At {name}, `external` overlap method is specified, but the selected method is" - f" {method}" - ) - assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( - f"At {name}, `external` overlap method is specified, but the external gemm" - f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" - ) + if method != "external": + raise ValueError( + f"At {name}, `external` overlap method is specified, but the selected method " + f"is {method}" + ) + if external_gemm_to_overlap[name] not in methods["ring_exchange"]: + raise ValueError( + f"At {name}, `external` overlap method is specified, but the external gemm " + f"{external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" + ) buffer_dtype = ( torch.uint8 @@ -424,7 +450,12 @@ def add_ub( and user_ub_cfg[name]["method"] != "bulk" ): wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in user_ub_cfg + if wgrad_name in user_ub_cfg: + raise ValueError( + f"Cannot specify user UB config for '{wgrad_name}' 
when its " + f"corresponding dgrad '{name}' uses a non-bulk overlap method " + f"('{user_ub_cfg[name]['method']}')" + ) layers_reduce_scatter_overlap.remove(wgrad_name) layers_all_gather_overlap.remove(name) layers_reduce_scatter_overlap.append(name) @@ -451,8 +482,10 @@ def get_ub(name: str, use_fp8: bool): # So favour simplicity until the correct design becomes clear. # This is mainly an internal API so we don't need to worry about future changes key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) - assert _ub_communicators is not None, "UB manager is not initialized." - assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." + if _ub_communicators is None: + raise RuntimeError("UB manager is not initialized.") + if key not in _ub_communicators: + raise KeyError(f"UB for {name} with use_fp8={use_fp8} is not registered.") return _ub_communicators[key] @@ -608,7 +641,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self, name: Optional[str] = None) -> None: super().__init__() - assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + if not torch.cuda.is_available(): + raise RuntimeError("TransformerEngine needs CUDA.") self.name = name self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False @@ -694,9 +728,12 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> ] for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): if buffer_key in FP8GlobalStateManager.global_amax_buffer: - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." 
+ if buffer_key not in FP8GlobalStateManager.global_amax_history_buffer: + raise RuntimeError( + f"TE internal error during amax history change: " + f"buffer_key '{buffer_key}' found in global_amax_buffer " + f"but missing from global_amax_history_buffer" + ) FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ meta_key ].amax_history[0] @@ -745,10 +782,11 @@ def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() weight_quantizers = self._get_weight_quantizers() - assert len(weight_tensors) == len(weight_quantizers), ( - f"Number of weight tensors ({len(weight_tensors)}) and quantizers " - f"({len(weight_quantizers)}) must match" - ) + if len(weight_tensors) != len(weight_quantizers): + raise ValueError( + f"Number of weight tensors ({len(weight_tensors)}) and quantizers " + f"({len(weight_quantizers)}) must match" + ) for weight, quantizer in zip(weight_tensors, weight_quantizers): if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_quantizer(quantizer) @@ -796,7 +834,11 @@ def reset(key): torch.zeros_like(self.fp8_meta[key].amax_history) ) else: - assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + if key not in fp8_meta_tensors: + raise KeyError( + f"Cannot reset fp8 tensors: key '{key}' not found in fp8_meta_tensors. " + f"Available keys: {list(fp8_meta_tensors.keys())}" + ) self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) @@ -938,10 +980,11 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. 
" - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) + if dtype != param.dtype: + raise TypeError( + "Data types for parameters must match when outside of autocasted " + f"region. Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.fast_setattr("activation_dtype", dtype) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -1046,10 +1089,17 @@ def prepare_forward( delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: - assert inp.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise RuntimeError( + f"TransformerEngine needs CUDA. Got input on device: {inp.device}" + ) if self.tp_size > 1: - assert self.tp_group_initialized, "TP group not initialized." + if not self.tp_group_initialized: + raise RuntimeError( + "Tensor parallel group not initialized. Call " + "set_tensor_parallel_group() before forward pass when tp_size > 1." + ) self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) @@ -1058,10 +1108,11 @@ def prepare_forward( delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() if delayed_scaling_recipe: if self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) + if not self.fp8_meta["recipe"].reduce_amax: + raise ValueError( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." 
+ ) if not FP8GlobalStateManager.fp8_graph_capturing(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 5beeed1262..8eb3e5823a 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -42,10 +42,20 @@ def forward( return inp, torch.tensor([], device=inp.device) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert index.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError( + f"inp must be a CUDA tensor, but got tensor on {inp.device}." + ) + if not index.is_cuda: + raise ValueError( + f"index must be a CUDA tensor, but got tensor on {index.device}." + ) # Shape check - assert inp.size(0) == index.size(0), "Permute not possible" + if inp.size(0) != index.size(0): + raise ValueError( + f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " + f"index.size(0) ({index.size(0)})." + ) # Data type check dtype = TE_DType[inp.dtype] @@ -119,7 +129,10 @@ def forward( # None probs check if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError( + f"probs must be a CUDA tensor, but got tensor on {probs.device}." + ) if probs.dtype != torch.float32: warnings.warn( @@ -136,8 +149,14 @@ def forward( probs = torch.empty(0) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError( + f"inp must be a CUDA tensor, but got tensor on {inp.device}." + ) + if not row_id_map.is_cuda: + raise ValueError( + f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." 
+ ) # Data type check dtype = TE_DType[inp.dtype] @@ -198,19 +217,34 @@ def forward( ctx.probs = probs return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert routing_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError( + f"inp must be a CUDA tensor, but got tensor on {inp.device}." + ) + if not routing_map.is_cuda: + raise ValueError( + f"routing_map must be a CUDA tensor, but got tensor on {routing_map.device}." + ) if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError( + f"probs must be a CUDA tensor, but got tensor on {probs.device}." + ) if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." + if not pad_offsets.is_cuda: + raise ValueError( + f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + ) - assert inp.size(0) == routing_map.size(0), "Permute not possible" + if inp.size(0) != routing_map.size(0): + raise ValueError( + f"Permute not possible: inp.size(0) ({inp.size(0)}) must match " + f"routing_map.size(0) ({routing_map.size(0)})." + ) num_tokens, hidden_size = inp.size() num_experts = routing_map.size(1) - assert ( - num_out_tokens is not None - ), "num_out_tokens must be provided to the fused permute function." + if num_out_tokens is None: + raise ValueError("num_out_tokens must be provided to the fused permute function.") row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) @@ -226,13 +260,25 @@ def forward( if blockwise_recipe: fp8_scale = inp._rowwise_scale_inv.T.contiguous() scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). 
" + f"Input shape: ({num_tokens}, {hidden_size}), " + f"scale shape: {tuple(fp8_scale.shape)}." + ) inp = inp._rowwise_data # mxfp8 scaling elif mxfp8_recipe: fp8_scale = inp._rowwise_scale_inv.contiguous() scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Input shape: ({num_tokens}, {hidden_size}), " + f"scale shape: {tuple(fp8_scale.shape)}." + ) inp = inp._rowwise_data # per-tensor scaling elif per_tensor_recipe: @@ -318,9 +364,11 @@ def backward( probs_grad = None if ctx.needs_input_grad[0]: row_id_map, pad_offsets = ctx.saved_tensors - assert not isinstance( - permuted_act_grad, QuantizedTensor - ), "The backward of moe_permute does not support FP8." + if isinstance(permuted_act_grad, QuantizedTensor): + raise TypeError( + "The backward of moe_permute does not support FP8, but got " + f"QuantizedTensor of type {type(permuted_act_grad).__name__}." + ) act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( permuted_act_grad, row_id_map, @@ -360,17 +408,32 @@ def forward( with_probs = merging_probs is not None if with_probs: - assert merging_probs.is_cuda, "TransformerEngine needs CUDA." + if not merging_probs.is_cuda: + raise ValueError( + f"merging_probs must be a CUDA tensor, but got tensor on " + f"{merging_probs.device}." + ) # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError( + f"inp must be a CUDA tensor, but got tensor on {inp.device}." + ) + if not row_id_map.is_cuda: + raise ValueError( + f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." + ) if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
+ if not pad_offsets.is_cuda: + raise ValueError( + f"pad_offsets must be a CUDA tensor, but got tensor on {pad_offsets.device}." + ) - assert not isinstance( - inp, QuantizedTensor - ), "The forward of moe_unpermute does not support FP8." + if isinstance(inp, QuantizedTensor): + raise TypeError( + "The forward of moe_unpermute does not support FP8, but got " + f"QuantizedTensor of type {type(inp).__name__}." + ) unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( inp, row_id_map, @@ -427,13 +490,23 @@ def backward(ctx, unpermuted_act_grad): fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if ctx.num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({ctx.num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Scale shape: {tuple(fp8_scale.shape)}." + ) # mxfp8 scaling elif mxfp8_recipe: fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() unpermuted_act_grad = unpermuted_act_grad._rowwise_data scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + if ctx.num_tokens != fp8_scale.shape[0]: + raise ValueError( + f"Scale and input shape mismatch: num_tokens ({ctx.num_tokens}) != " + f"fp8_scale.shape[0] ({fp8_scale.shape[0]}). " + f"Scale shape: {tuple(fp8_scale.shape)}." + ) else: raise ValueError("Unsupported FP8 recipe") else: @@ -442,9 +515,11 @@ def backward(ctx, unpermuted_act_grad): fp8_scale = None if ctx.with_probs: - assert ( - not fp8 - ), "The backward of moe_unpermute with merging probs does not support FP8." + if fp8: + raise TypeError( + "The backward of moe_unpermute with merging probs does not support FP8, " + f"but got FP8 gradient with dtype {fp8_dtype}." 
+ ) act_grad, probs_grad = ( triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( unpermuted_act_grad, @@ -619,10 +694,12 @@ def moe_permute_and_pad_with_probs( align_size : int the alignment size for the input tensor. """ - assert ( - tokens_per_expert is not None - ), "tokens_per_expert must be provided to the fused permute padding function." - assert align_size > 0, f"align_size must be positive, got {align_size}" + if tokens_per_expert is None: + raise ValueError( + "tokens_per_expert must be provided to the fused permute padding function." + ) + if align_size <= 0: + raise ValueError(f"align_size must be positive, got {align_size}.") # Ensure tokens_per_expert is on the same device as input to avoid device transfers if tokens_per_expert.device != inp.device: @@ -713,15 +790,31 @@ def forward( if not inp.numel(): return inp, probs - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert split_sizes.is_cuda, "TransformerEngine needs CUDA." - assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + if not inp.is_cuda: + raise ValueError( + f"inp must be a CUDA tensor, but got tensor on {inp.device}." + ) + if not split_sizes.is_cuda: + raise ValueError( + f"split_sizes must be a CUDA tensor, but got tensor on {split_sizes.device}." + ) + if not sorted_idxs.is_cuda: + raise ValueError( + f"sorted_idxs must be a CUDA tensor, but got tensor on {sorted_idxs.device}." + ) if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." + if not probs.is_cuda: + raise ValueError( + f"probs must be a CUDA tensor, but got tensor on {probs.device}." + ) num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) - assert num_splits == sorted_idxs.size(0) + if num_splits != sorted_idxs.size(0): + raise ValueError( + f"split_sizes.size(0) ({num_splits}) must match " + f"sorted_idxs.size(0) ({sorted_idxs.size(0)})." 
+ ) fp8 = isinstance(inp, Float8Tensor) if fp8: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..47e6d5c8dc 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -97,7 +97,8 @@ def check_recipe_support(recipe: Recipe) -> None: recipe_supported, unsupported_reason = check_fp8_block_scaling_support() elif isinstance(recipe, MXFP8BlockScaling): recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason + if not recipe_supported: + raise RuntimeError(unsupported_reason) def get_default_fp8_recipe() -> Recipe: diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 05e2d22e9c..c38864162e 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -193,10 +193,17 @@ def _cast_master_weights_to_fp8_delayed_scaling( continue # If master weight is not None, start_offset must be a valid value. - assert start_offset is not None - assert start_offset >= 0 + assert start_offset is not None, ( + "start_offset must not be None when master_weight is provided" + ) + assert start_offset >= 0, ( + f"start_offset must be non-negative, got {start_offset}" + ) end_offset = start_offset + master_weight.numel() - assert end_offset <= model_weight.numel() + assert end_offset <= model_weight.numel(), ( + f"end_offset ({end_offset}) exceeds model_weight numel ({model_weight.numel()}), " + f"start_offset={start_offset}, master_weight numel={master_weight.numel()}" + ) # master_weight may be smaller than model_weight because it could be distributed across # multiple ranks. So we need to create a dummy weight using the raw data from model_weight. @@ -280,9 +287,18 @@ def _cast_master_weights_to_fp8_current_scaling( # Make sure all the model weights have the same numerical options. 
quantizer = model_weight._get_quantizer() - assert quantizer.dtype == fp8_dtype - assert quantizer.force_pow_2_scales == force_pow_2_scales - assert quantizer.amax_epsilon == amax_epsilon + assert quantizer.dtype == fp8_dtype, ( + f"All model weights must have the same fp8 dtype, " + f"expected {fp8_dtype} but got {quantizer.dtype}" + ) + assert quantizer.force_pow_2_scales == force_pow_2_scales, ( + f"All model weights must have the same force_pow_2_scales, " + f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" + ) + assert quantizer.amax_epsilon == amax_epsilon, ( + f"All model weights must have the same amax_epsilon, " + f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" + ) scales.append(quantizer.scale.view(1)) scale_invs.append(model_weight._scale_inv.view(1)) @@ -396,19 +412,41 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() - assert block_len == quantizer.block_len - assert fp8_dtype == quantizer.dtype - assert force_pow_2_scales == quantizer.force_pow_2_scales - assert amax_epsilon == quantizer.amax_epsilon + assert block_len == quantizer.block_len, ( + f"All model weights must have the same block_len, " + f"expected {block_len} but got {quantizer.block_len}" + ) + assert fp8_dtype == quantizer.dtype, ( + f"All model weights must have the same fp8 dtype, " + f"expected {fp8_dtype} but got {quantizer.dtype}" + ) + assert force_pow_2_scales == quantizer.force_pow_2_scales, ( + f"All model weights must have the same force_pow_2_scales, " + f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" + ) + assert amax_epsilon == quantizer.amax_epsilon, ( + f"All model weights must have the same amax_epsilon, " + f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" + ) scale_shape = quantizer.get_scale_shape(model_weight.shape, False) amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 
1]].reshape(scale_shape) scale = torch.empty(scale_shape, dtype=torch.float32, device=device) scale_inv = model_weight._rowwise_scale_inv - assert len(scale_shape) == 2 - assert len(scale_inv.shape) == 2 - assert scale_inv.shape[0] == scale_shape[0] - assert scale_inv.shape[1] == scale_shape[1] + assert len(scale_shape) == 2, ( + f"scale_shape must be 2D, got {len(scale_shape)}D shape {scale_shape}" + ) + assert len(scale_inv.shape) == 2, ( + f"scale_inv must be 2D, got {len(scale_inv.shape)}D shape {scale_inv.shape}" + ) + assert scale_inv.shape[0] == scale_shape[0], ( + f"scale_inv dim 0 mismatch: scale_inv.shape={scale_inv.shape}, " + f"scale_shape={scale_shape}" + ) + assert scale_inv.shape[1] == scale_shape[1], ( + f"scale_inv dim 1 mismatch: scale_inv.shape={scale_inv.shape}, " + f"scale_shape={scale_shape}" + ) amaxes.append(amax) scales.append(scale) @@ -416,7 +454,10 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Compute amax of the master weight and store it in packed_amaxes. 
if master_weight is not None: - assert len(model_weight.shape) == 2 + assert len(model_weight.shape) == 2, ( + f"model_weight must be 2D for blockwise scaling, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.fp8_block_scaling_compute_partial_amax( master_weight, amax, h, w, start_offset, block_len @@ -467,7 +508,10 @@ def _cast_master_weights_to_fp8_blockwise_scaling( end_offset = start_offset + master_weight.numel() if not use_fsdp_shard_model_weights: model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] - assert len(model_weight.shape) == 2 + assert len(model_weight.shape) == 2, ( + f"model_weight must be 2D for blockwise scaling partial cast, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype @@ -500,9 +544,15 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( cu_colwise_amax_sizes = [0] for model_weight, _, _, _ in params: rowwise_shape = model_weight._rowwise_scale_inv.shape - assert len(rowwise_shape) == 2 + assert len(rowwise_shape) == 2, ( + f"rowwise_scale_inv must be 2D, " + f"got {len(rowwise_shape)}D shape {rowwise_shape}" + ) colwise_shape = model_weight._columnwise_scale_inv.shape - assert len(colwise_shape) == 2 + assert len(colwise_shape) == 2, ( + f"columnwise_scale_inv must be 2D, " + f"got {len(colwise_shape)}D shape {colwise_shape}" + ) cu_rowwise_amax_sizes.append( cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1] ) @@ -541,7 +591,10 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( # Compute amax of the master weight and store it in packed_amaxes. 
if master_weight is not None: - assert len(model_weight.shape) == 2 + assert len(model_weight.shape) == 2, ( + f"model_weight must be 2D for MXFP8 scaling, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.mxfp8_scaling_compute_partial_amax( master_weight, amax_rowwise, amax_colwise, h, w, start_offset @@ -585,7 +638,10 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( else: rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset] - assert len(model_weight.shape) == 2 + assert len(model_weight.shape) == 2, ( + f"model_weight must be 2D for MXFP8 scaling partial cast, " + f"got {len(model_weight.shape)}D shape {model_weight.shape}" + ) h, w = model_weight.shape tex.mxfp8_scaling_partial_cast( master_weight, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index cf7ce5e1a4..1407205adb 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -373,23 +373,35 @@ def __init__( self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm if parallel_attention_mlp: - assert self.layer_type == "encoder", "parallel_attention requires layer_type='encoder'" - assert not self.apply_residual_connection_post_layernorm, ( - "parallel_attention and apply_residual_connection_post_layernorm " - "not supported simultaneously." - ) - assert ( - not self.output_layernorm - ), "parallel_attention and output_layernorm not supported simultaneously" + if self.layer_type != "encoder": + raise ValueError( + f"parallel_attention requires layer_type='encoder', " + f"but got layer_type={self.layer_type!r}" + ) + if self.apply_residual_connection_post_layernorm: + raise ValueError( + "parallel_attention and apply_residual_connection_post_layernorm " + "are not supported simultaneously." 
+ ) + if self.output_layernorm: + raise ValueError( + "parallel_attention and output_layernorm are not supported simultaneously." + ) self.parallel_attention_mlp = parallel_attention_mlp - assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" + if layer_type not in LayerTypes: + raise ValueError( + f"layer_type {layer_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in LayerTypes)}" + ) if not fuse_qkv_params: - assert ( - not fuse_wgrad_accumulation - ), "Gradient accumulation fusion requires single QKV parameter." + if fuse_wgrad_accumulation: + raise ValueError( + "Gradient accumulation fusion (fuse_wgrad_accumulation=True) " + "requires fuse_qkv_params=True, but fuse_qkv_params is False." + ) if not fuse_qkv_params: qkv_weight_interleaved = False @@ -792,32 +804,60 @@ def forward( }: enc_dec_bottom_right_diagonal = True - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" - assert ( - enc_dec_attn_mask_type in AttnMaskTypes - ), f"enc_dec_attn_mask_type {enc_dec_attn_mask_type} not supported" + if self_attn_mask_type not in AttnMaskTypes: + raise ValueError( + f"self_attn_mask_type {self_attn_mask_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in AttnMaskTypes)}" + ) + if enc_dec_attn_mask_type not in AttnMaskTypes: + raise ValueError( + f"enc_dec_attn_mask_type {enc_dec_attn_mask_type!r} is not supported. " + f"Supported types are: {', '.join(repr(t) for t in AttnMaskTypes)}" + ) hidden_states = hidden_states.contiguous() if self.sequence_parallel and self.seq_length is not None: - assert ( - hidden_states.shape[0] == self.seq_length // self.tp_size - ), "Sequence dimension must be split across TP group when using sequence parallel." + if hidden_states.shape[0] != self.seq_length // self.tp_size: + raise ValueError( + f"Sequence dimension must be split across TP group when using " + f"sequence parallel. 
Expected hidden_states.shape[0] to be " + f"{self.seq_length // self.tp_size} " + f"(seq_length={self.seq_length} // tp_size={self.tp_size}), " + f"but got {hidden_states.shape[0]}." + ) if ( "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" ) and attention_mask is not None: - assert all( + if not all( attention_mask[i].dtype == torch.bool for i in range(len(attention_mask)) - ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors" + ): + non_bool_dtypes = [ + (i, attention_mask[i].dtype) + for i in range(len(attention_mask)) + if attention_mask[i].dtype != torch.bool + ] + raise TypeError( + f"Attention mask must be a boolean tensor or a list/tuple of boolean " + f"tensors, but found non-bool dtypes at indices: {non_bool_dtypes}" + ) if ( "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" ) and enc_dec_attn_mask is not None: - assert all( - enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) - ), "Encoder-decoder attention mask must be boolean tensor(s)" + if not all( + enc_dec_attn_mask[i].dtype == torch.bool + for i in range(len(enc_dec_attn_mask)) + ): + non_bool_dtypes = [ + (i, enc_dec_attn_mask[i].dtype) + for i in range(len(enc_dec_attn_mask)) + if enc_dec_attn_mask[i].dtype != torch.bool + ] + raise TypeError( + f"Encoder-decoder attention mask must be boolean tensor(s), " + f"but found non-bool dtypes at indices: {non_bool_dtypes}" + ) # For AMP if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 47af9fabe1..9607bea96f 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -147,7 +147,8 @@ def compare_tensors(a: torch.Tensor, b: torch.Tensor) -> None: def ensure_divisibility(numerator: int, denominator: int) -> None: """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, f"{numerator} is not 
divisible by {denominator}" + if numerator % denominator != 0: + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator: int, denominator: int) -> int: @@ -271,13 +272,16 @@ def forward( @staticmethod def backward(ctx, *grad_outputs): # pylint: disable=missing-function-docstring - assert len(grad_outputs) > 0, "No gradients received for backprop!" + if len(grad_outputs) == 0: + raise RuntimeError("No gradients received for backprop!") if isinstance(ctx.split_size_or_sections, (list, tuple)): split_sizes = ctx.split_size_or_sections - assert len(grad_outputs) == len( - split_sizes - ), "Unequal number of gradients vs split sections for backprop!" + if len(grad_outputs) != len(split_sizes): + raise RuntimeError( + f"Unequal number of gradients ({len(grad_outputs)}) vs " + f"split sections ({len(split_sizes)}) for backprop!" + ) if isinstance(ctx.split_size_or_sections, int): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) @@ -371,7 +375,10 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: """Checks if passed in param function has everything required for tensor/model and sequence parallel. 
""" - assert callable(get_rng_tracker), "get_rng_tracker is not a valid function" + if not callable(get_rng_tracker): + raise TypeError( + f"get_rng_tracker must be callable, got {type(get_rng_tracker).__name__}" + ) rng_tracker = None try: @@ -379,15 +386,15 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: except Exception as e: raise RuntimeError("Cannot call get_rng_tracker function") from e - assert hasattr(rng_tracker, "get_states") and callable( - rng_tracker.get_states - ), "rng_tracker object does not have valid method get_states" - assert hasattr(rng_tracker, "set_states") and callable( - rng_tracker.set_states - ), "rng_tracker object does not have valid method set_states" - assert hasattr(rng_tracker, "fork") and callable( - rng_tracker.fork - ), "rng_tracker object does not have valid method fork" + for method_name in ("get_states", "set_states", "fork"): + if not hasattr(rng_tracker, method_name) or not callable( + getattr(rng_tracker, method_name) + ): + raise TypeError( + f"rng_tracker object ({type(rng_tracker).__name__}) does not have " + f"a valid callable method '{method_name}'. " + "Required methods: get_states, set_states, fork." + ) validate_ctx_manager(rng_tracker.fork) @@ -398,11 +405,12 @@ def assert_viewless_tensor(tensor: torch.Tensor, extra_msg: Optional[str] = None return [assert_viewless_tensor(t) for t in tensor] if not isinstance(tensor, torch.Tensor): return tensor - assert tensor._base is None, ( - "Ensure tensor._base is None before setting tensor.data or storing " - "tensor to memory buffer. Otherwise, a memory leak will occur (and " - f"likely accumulate over iterations). {extra_msg}" - ) + if tensor._base is not None: + raise ValueError( + "Ensure tensor._base is None before setting tensor.data or storing " + "tensor to memory buffer. Otherwise, a memory leak will occur (and " + f"likely accumulate over iterations). 
{extra_msg}" + ) return tensor @@ -440,11 +448,13 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" for tensor in tensors: - assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, ( - "FP8 execution requires the product of all dimensions except the last to be divisible" - " by 8 and the last dimension to be divisible by 16, but got tensor with" - f" dims={list(tensor.size())}" - ) + if math.prod(tensor.shape[:-1]) % 8 != 0 or tensor.shape[-1] % 16 != 0: + raise ValueError( + "FP8 execution requires the product of all dimensions except the last to be" + " divisible by 8 and the last dimension to be divisible by 16, but got tensor" + f" with dims={list(tensor.size())} (product of leading dims =" + f" {math.prod(tensor.shape[:-1])}, last dim = {tensor.shape[-1]})" + ) def assert_dim_for_all_gather( @@ -452,9 +462,12 @@ def assert_dim_for_all_gather( ) -> None: """Assert that tensor dimensions are supported for all-gather""" if with_all_gather: - assert quantizer.is_quantizable(tensor), ( - "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ - ) + if not quantizer.is_quantizable(tensor): + raise ValueError( + f"All-gather requires a quantizable tensor for quantizer" + f" {quantizer.__class__.__name__}, but got tensor with" + f" shape={list(tensor.shape)} and dtype={tensor.dtype}" + ) def is_bf16_compatible() -> bool: @@ -752,7 +765,11 @@ def __cuda_array_interface__(self): def torch_dtype_to_np_typestr(self): """Convert PyTorch dtype to numpy typestr.""" ret = _torch_dtype_to_np_typestr_dict.get(self.dtype) - assert ret is not None, f"Unsupported dtype: {self.dtype}" + if ret is None: + supported = ", ".join(str(d) for d in _torch_dtype_to_np_typestr_dict) + raise TypeError( + f"Unsupported dtype: {self.dtype}. 
Supported dtypes: {supported}" + ) return ret @@ -791,4 +808,7 @@ def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torc return x if x is None: return None - raise TypeError(f"Invalid type {type(x)} to make weak ref") + raise TypeError( + f"Invalid type {type(x).__name__} to make weak ref. " + "Valid types are: torch.Tensor, tuple, list, dict, int, float, bool, and None." + ) From ffdb3896663d4ee7322bc5d2b01d01c02f457fe3 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 25 Feb 2026 11:53:45 -0800 Subject: [PATCH 2/3] Update transformer_engine/pytorch/distributed.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak --- transformer_engine/pytorch/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 2e7c53b9d3..96241e7509 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1870,7 +1870,7 @@ def symmetric_all_reduce( - The second element is the async work handle if async_op=True, otherwise None. """ - if async_op is not False: + if async_op: raise RuntimeError( f"Async symmetric ops are not supported yet, but async_op={async_op!r} was passed." 
) From 221d723306208194e69c1fb27c480f6b29a54217 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:53:46 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/common.h | 431 +++++++++--------- .../common/fused_attn/fused_attn.cpp | 114 +++-- .../common/fused_router/utils.h | 84 ++-- .../common/gemm/cutlass_grouped_gemm.cuh | 9 +- .../common/normalization/common.cpp | 10 +- .../quantize_transpose_square_blockwise.cu | 3 +- .../jax/cpp_extensions/attention.py | 95 +++- transformer_engine/jax/cpp_extensions/gemm.py | 76 ++- .../jax/cpp_extensions/normalization.py | 95 +++- transformer_engine/jax/flax/transformer.py | 78 ++-- transformer_engine/jax/layernorm.py | 2 +- .../pytorch/cpp_extensions/fused_attn.py | 26 +- transformer_engine/pytorch/cpu_offload.py | 6 +- transformer_engine/pytorch/csrc/common.cpp | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 5 +- .../pytorch/custom_recipes/gemm.py | 12 +- .../custom_recipes/quantization_nvfp4.py | 41 +- transformer_engine/pytorch/distributed.py | 20 +- transformer_engine/pytorch/graph.py | 24 +- transformer_engine/pytorch/module/base.py | 17 +- transformer_engine/pytorch/permutation.py | 38 +- transformer_engine/pytorch/tensor/utils.py | 72 ++- transformer_engine/pytorch/transformer.py | 17 +- transformer_engine/pytorch/utils.py | 14 +- 24 files changed, 718 insertions(+), 574 deletions(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 4e9800a69d..7f3ba31b90 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -606,154 +606,149 @@ struct TypeInfo { #define SWITCH_FP4_TYPE_HANDLE(type, ...) // do nothing #endif -#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) 
\ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kByte: { \ - using type = unsigned char; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt16: { \ - using type = int16_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt32: { \ - using type = int32_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt64: { \ - using type = int64_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E8M0: { \ - using type = byte; \ - { __VA_ARGS__ } \ - } break; \ - SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ - default: \ - NVTE_ERROR("Unsupported dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Byte, Int32, Int64, Float32, " \ - "Float16, BFloat16, Float8E4M3, Float8E5M2, " \ - "Float8E8M0, Float4E2M1."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kByte: { \ + using type = unsigned char; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt16: { \ + using type = int16_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ + SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Byte, Int32, Int64, Float32, " \ + "Float16, BFloat16, Float8E4M3, Float8E5M2, " \ + "Float8E8M0, Float4E2M1."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float32, Float16, BFloat16, " \ - "Float8E4M3, Float8E5M2."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FLOAT(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E4M3, Float8E5M2."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported output dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float32, Float16, BFloat16, " \ - "Float8E5M2, Float8E4M3."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported output dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16, " \ + "Float8E5M2, Float8E4M3."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float32, Float16, BFloat16."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) 
\ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float32, BFloat16."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, BFloat16."); \ } // Add a pack_size argument to select the packed type for FP4 @@ -765,94 +760,90 @@ struct TypeInfo { { __VA_ARGS__ } \ } break; \ default: \ - NVTE_ERROR("Unsupported dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected: Float4E2M1."); \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". Expected: Float4E2M1."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float8E5M2, Float8E4M3."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported dtype ", to_string(static_cast(dtype)), \ + ". 
Expected one of: Float8E5M2, Float8E4M3."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat8E5M2: \ - case DType::kFloat8E4M3: { \ - NVTE_ERROR("FP8 dtype ", \ - to_string(static_cast(dtype)), \ - " is not instantiated for input. " \ - "Expected one of: Float32, Float16, BFloat16."); \ - } break; \ - case DType::kFloat4E2M1: { \ - NVTE_ERROR("FP4 dtype Float4E2M1 is not instantiated " \ - "for input. Expected one of: Float32, Float16, " \ - "BFloat16."); \ - } break; \ - default: \ - NVTE_ERROR("Unsupported input dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float32, Float16, BFloat16."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat8E5M2: \ + case DType::kFloat8E4M3: { \ + NVTE_ERROR("FP8 dtype ", to_string(static_cast(dtype)), \ + " is not instantiated for input. " \ + "Expected one of: Float32, Float16, BFloat16."); \ + } break; \ + case DType::kFloat4E2M1: { \ + NVTE_ERROR( \ + "FP4 dtype Float4E2M1 is not instantiated " \ + "for input. Expected one of: Float32, Float16, " \ + "BFloat16."); \ + } break; \ + default: \ + NVTE_ERROR("Unsupported input dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) 
\ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat16: { \ - using type = fp16; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kBFloat16: { \ - using type = bf16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - NVTE_ERROR("Unsupported 16-bit dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float16, BFloat16."); \ +#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Unsupported 16-bit dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float16, BFloat16."); \ } -#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ - switch (SCALE_DIM) { \ - case 1: { \ - constexpr size_t DIM = 1; \ - { __VA_ARGS__ } \ - } break; \ - case 32: { \ - constexpr size_t DIM = 32; \ - { __VA_ARGS__ } \ - } break; \ - default: { \ - NVTE_ERROR("Unsupported MX scaling factor dimension ", \ - SCALE_DIM, ". Expected one of: 1, 32."); \ - } \ +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported MX scaling factor dimension ", SCALE_DIM, \ + ". Expected one of: 1, 32."); \ + } \ } #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) 
\ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 1d94c61fc2..6a136c67e4 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -19,45 +19,79 @@ namespace transformer_engine { std::string to_string(NVTE_QKV_Layout layout) { switch (layout) { - case NVTE_SB3HD: return "NVTE_SB3HD"; - case NVTE_SBH3D: return "NVTE_SBH3D"; - case NVTE_SBHD_SB2HD: return "NVTE_SBHD_SB2HD"; - case NVTE_SBHD_SBH2D: return "NVTE_SBHD_SBH2D"; - case NVTE_SBHD_SBHD_SBHD: return "NVTE_SBHD_SBHD_SBHD"; - case NVTE_BS3HD: return "NVTE_BS3HD"; - case NVTE_BSH3D: return "NVTE_BSH3D"; - case NVTE_BSHD_BS2HD: return "NVTE_BSHD_BS2HD"; - case NVTE_BSHD_BSH2D: return "NVTE_BSHD_BSH2D"; - case NVTE_BSHD_BSHD_BSHD: return "NVTE_BSHD_BSHD_BSHD"; - case NVTE_T3HD: return "NVTE_T3HD"; - case NVTE_TH3D: return "NVTE_TH3D"; - case NVTE_THD_T2HD: return "NVTE_THD_T2HD"; - case NVTE_THD_TH2D: return "NVTE_THD_TH2D"; - case NVTE_THD_THD_THD: return "NVTE_THD_THD_THD"; - case NVTE_SBHD_BSHD_BSHD: return "NVTE_SBHD_BSHD_BSHD"; - case NVTE_BSHD_SBHD_SBHD: return "NVTE_BSHD_SBHD_SBHD"; - case NVTE_THD_BSHD_BSHD: return "NVTE_THD_BSHD_BSHD"; - case NVTE_THD_SBHD_SBHD: return "NVTE_THD_SBHD_SBHD"; - case NVTE_Paged_KV_BSHD_BSHD_BSHD: return "NVTE_Paged_KV_BSHD_BSHD_BSHD"; - case NVTE_Paged_KV_BSHD_SBHD_SBHD: return "NVTE_Paged_KV_BSHD_SBHD_SBHD"; - case NVTE_Paged_KV_SBHD_BSHD_BSHD: return "NVTE_Paged_KV_SBHD_BSHD_BSHD"; - case NVTE_Paged_KV_SBHD_SBHD_SBHD: return "NVTE_Paged_KV_SBHD_SBHD_SBHD"; - case NVTE_Paged_KV_THD_BSHD_BSHD: return "NVTE_Paged_KV_THD_BSHD_BSHD"; - case NVTE_Paged_KV_THD_SBHD_SBHD: return "NVTE_Paged_KV_THD_SBHD_SBHD"; - default: return "UNKNOWN_QKV_LAYOUT(" + std::to_string(static_cast(layout)) + ")"; + case NVTE_SB3HD: + return "NVTE_SB3HD"; + case NVTE_SBH3D: + return "NVTE_SBH3D"; + case NVTE_SBHD_SB2HD: + return "NVTE_SBHD_SB2HD"; + case 
NVTE_SBHD_SBH2D: + return "NVTE_SBHD_SBH2D"; + case NVTE_SBHD_SBHD_SBHD: + return "NVTE_SBHD_SBHD_SBHD"; + case NVTE_BS3HD: + return "NVTE_BS3HD"; + case NVTE_BSH3D: + return "NVTE_BSH3D"; + case NVTE_BSHD_BS2HD: + return "NVTE_BSHD_BS2HD"; + case NVTE_BSHD_BSH2D: + return "NVTE_BSHD_BSH2D"; + case NVTE_BSHD_BSHD_BSHD: + return "NVTE_BSHD_BSHD_BSHD"; + case NVTE_T3HD: + return "NVTE_T3HD"; + case NVTE_TH3D: + return "NVTE_TH3D"; + case NVTE_THD_T2HD: + return "NVTE_THD_T2HD"; + case NVTE_THD_TH2D: + return "NVTE_THD_TH2D"; + case NVTE_THD_THD_THD: + return "NVTE_THD_THD_THD"; + case NVTE_SBHD_BSHD_BSHD: + return "NVTE_SBHD_BSHD_BSHD"; + case NVTE_BSHD_SBHD_SBHD: + return "NVTE_BSHD_SBHD_SBHD"; + case NVTE_THD_BSHD_BSHD: + return "NVTE_THD_BSHD_BSHD"; + case NVTE_THD_SBHD_SBHD: + return "NVTE_THD_SBHD_SBHD"; + case NVTE_Paged_KV_BSHD_BSHD_BSHD: + return "NVTE_Paged_KV_BSHD_BSHD_BSHD"; + case NVTE_Paged_KV_BSHD_SBHD_SBHD: + return "NVTE_Paged_KV_BSHD_SBHD_SBHD"; + case NVTE_Paged_KV_SBHD_BSHD_BSHD: + return "NVTE_Paged_KV_SBHD_BSHD_BSHD"; + case NVTE_Paged_KV_SBHD_SBHD_SBHD: + return "NVTE_Paged_KV_SBHD_SBHD_SBHD"; + case NVTE_Paged_KV_THD_BSHD_BSHD: + return "NVTE_Paged_KV_THD_BSHD_BSHD"; + case NVTE_Paged_KV_THD_SBHD_SBHD: + return "NVTE_Paged_KV_THD_SBHD_SBHD"; + default: + return "UNKNOWN_QKV_LAYOUT(" + std::to_string(static_cast(layout)) + ")"; } } std::string to_string(NVTE_QKV_Format format) { switch (format) { - case NVTE_SBHD: return "NVTE_SBHD"; - case NVTE_BSHD: return "NVTE_BSHD"; - case NVTE_THD: return "NVTE_THD"; - case NVTE_BSHD_2SBHD: return "NVTE_BSHD_2SBHD"; - case NVTE_SBHD_2BSHD: return "NVTE_SBHD_2BSHD"; - case NVTE_THD_2BSHD: return "NVTE_THD_2BSHD"; - case NVTE_THD_2SBHD: return "NVTE_THD_2SBHD"; - default: return "UNKNOWN_QKV_FORMAT(" + std::to_string(static_cast(format)) + ")"; + case NVTE_SBHD: + return "NVTE_SBHD"; + case NVTE_BSHD: + return "NVTE_BSHD"; + case NVTE_THD: + return "NVTE_THD"; + case NVTE_BSHD_2SBHD: + return 
"NVTE_BSHD_2SBHD"; + case NVTE_SBHD_2BSHD: + return "NVTE_SBHD_2BSHD"; + case NVTE_THD_2BSHD: + return "NVTE_THD_2BSHD"; + case NVTE_THD_2SBHD: + return "NVTE_THD_2SBHD"; + default: + return "UNKNOWN_QKV_FORMAT(" + std::to_string(static_cast(format)) + ")"; } } @@ -98,8 +132,7 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; default: - NVTE_ERROR("Unsupported qkv_layout ", - transformer_engine::to_string(qkv_layout), + NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_layout_group."); } } @@ -140,8 +173,7 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_THD_2SBHD; default: - NVTE_ERROR("Unsupported qkv_layout ", - transformer_engine::to_string(qkv_layout), + NVTE_ERROR("Unsupported qkv_layout ", transformer_engine::to_string(qkv_layout), " in nvte_get_qkv_format."); } } @@ -161,8 +193,7 @@ NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD_2SBHD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("Unsupported qkv_format ", - transformer_engine::to_string(qkv_format), + NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_q_format."); } } @@ -182,8 +213,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Format::NVTE_THD: return NVTE_QKV_Format::NVTE_THD; default: - NVTE_ERROR("Unsupported qkv_format ", - transformer_engine::to_string(qkv_format), + NVTE_ERROR("Unsupported qkv_format ", transformer_engine::to_string(qkv_format), " in nvte_get_kv_format."); } } diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 7dd2cf1d13..4ff8686e6c 100644 --- a/transformer_engine/common/fused_router/utils.h +++ 
b/transformer_engine/common/fused_router/utils.h @@ -216,51 +216,49 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future -#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat16: { \ - using type = fp16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported router probs dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Float32, Float16, BFloat16.");\ +#define TE_ROUTER_PROBS_TYPE_SWITCH_ALL(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported router probs dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Float32, Float16, BFloat16."); \ } -#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kInt32: { \ - using type = int32_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kInt64: { \ - using type = int64_t; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kBFloat16: { \ - using type = bf16; \ - { __VA_ARGS__ } \ - } break; \ - case DType::kFloat32: { \ - using type = float; \ - { __VA_ARGS__ } \ - } break; \ - default: \ - NVTE_ERROR("Unsupported router index dtype ", \ - to_string(static_cast(dtype)), \ - ". Expected one of: Int32, Int64, BFloat16, " \ - "Float32."); \ +#define TE_ROUTER_INDEX_TYPE_SWITCH_ALL(dtype, type, ...) 
\ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kInt32: { \ + using type = int32_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kInt64: { \ + using type = int64_t; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Unsupported router index dtype ", to_string(static_cast(dtype)), \ + ". Expected one of: Int32, Int64, BFloat16, " \ + "Float32."); \ } } // namespace transformer_engine #endif diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh index 16b37813f1..aa2bde4203 100644 --- a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -326,20 +326,17 @@ void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, // Check can implement the kernel. if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { - NVTE_ERROR("Failed to implement CUTLASS Grouped GEMM with ", - num_gemms, " GEMMs"); + NVTE_ERROR("Failed to implement CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } // Initialize the kernel. if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { - NVTE_ERROR("Failed to initialize CUTLASS Grouped GEMM with ", - num_gemms, " GEMMs"); + NVTE_ERROR("Failed to initialize CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } // Execute the kernel in the current stream. 
if (gemm.run(stream) != cutlass::Status::kSuccess) { - NVTE_ERROR("Failed to run CUTLASS Grouped GEMM with ", - num_gemms, " GEMMs"); + NVTE_ERROR("Failed to run CUTLASS Grouped GEMM with ", num_gemms, " GEMMs"); } } diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index ba8b9930da..11f12775c5 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -116,8 +116,9 @@ void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* beta_dptr, void* mean_dptr, void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Backward normalization should not call the forward execute function. " - "Use the backward-specific execute overload instead."); + NVTE_ERROR( + "Backward normalization should not call the forward execute function. " + "Use the backward-specific execute overload instead."); } template @@ -166,8 +167,9 @@ void TeNormalizationPlan::execute(void* x_dptr, void* gamma void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { - NVTE_ERROR("Forward normalization should not call the backward execute function. " - "Use the forward-specific execute overload instead."); + NVTE_ERROR( + "Forward normalization should not call the backward execute function. 
" + "Use the forward-specific execute overload instead."); } template <> diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 9c9c43b51a..3a8536587c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -463,7 +463,8 @@ CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size std::is_same_v) { dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; } else { - NVTE_ERROR("Invalid output type for blockwise transpose (must be FP8: Float8E4M3 or Float8E5M2)."); + NVTE_ERROR( + "Invalid output type for blockwise transpose (must be FP8: Float8E4M3 or Float8E5M2)."); } CUtensorMap tensor_map_output_trans{}; diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 66e3bb8784..861ddfadae 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -165,13 +165,25 @@ def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): kv_max_seqlen = q_max_seqlen num_gqa_groups = attn_heads v_head_dim = q_head_dim - assert nqkv == 3, f"Expected nqkv == 3 for qkvpacked layout, but got nqkv={nqkv} from q_aval.shape={q_aval.shape}" + assert nqkv == 3, ( + f"Expected nqkv == 3 for qkvpacked layout, but got nqkv={nqkv} from" + f" q_aval.shape={q_aval.shape}" + ) elif qkv_layout.is_kvpacked(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, v_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape, f"Mismatched batch shapes for kvpacked layout: q_batch_shape={q_batch_shape}, kv_batch_shape={kv_batch_shape}" - assert q_head_dim == v_head_dim, f"Mismatched head dims for kvpacked layout: q_head_dim={q_head_dim}, 
v_head_dim={v_head_dim}" - assert nkv == 2, f"Expected nkv == 2 for kvpacked layout, but got nkv={nkv} from k_aval.shape={k_aval.shape}" + assert q_batch_shape == kv_batch_shape, ( + f"Mismatched batch shapes for kvpacked layout: q_batch_shape={q_batch_shape}," + f" kv_batch_shape={kv_batch_shape}" + ) + assert q_head_dim == v_head_dim, ( + f"Mismatched head dims for kvpacked layout: q_head_dim={q_head_dim}," + f" v_head_dim={v_head_dim}" + ) + assert nkv == 2, ( + f"Expected nkv == 2 for kvpacked layout, but got nkv={nkv} from" + f" k_aval.shape={k_aval.shape}" + ) elif qkv_layout.is_separate(): *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape *k_batch_shape, k_max_seqlen, k_num_gqa_groups, k_head_dim = k_aval.shape @@ -244,9 +256,13 @@ def check_seed(self, seed, dropout_probability, is_training): ) seed = seed.astype(self.rng_state_dtype) - assert seed.dtype == self.rng_state_dtype, f"Expected seed.dtype={self.rng_state_dtype}, but got seed.dtype={seed.dtype}" + assert ( + seed.dtype == self.rng_state_dtype + ), f"Expected seed.dtype={self.rng_state_dtype}, but got seed.dtype={seed.dtype}" # Backend takes an int64_t seed, so only the first two u32 elements are taken - assert seed.size >= self.seed_size, f"Expected seed.size >= {self.seed_size}, but got seed.size={seed.size}" + assert ( + seed.size >= self.seed_size + ), f"Expected seed.size >= {self.seed_size}, but got seed.size={seed.size}" return seed @@ -363,7 +379,9 @@ def abstract( # 32-bit unsigned int to get the buffer size we need in the C++ kernel checker = _FusedAttnRNGStateChecker() seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype, f"Expected seed_dtype={checker.rng_state_dtype}, but got seed_dtype={seed_dtype}" + assert ( + seed_dtype == checker.rng_state_dtype + ), f"Expected seed_dtype={checker.rng_state_dtype}, but got seed_dtype={seed_dtype}" rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = 
seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) @@ -408,11 +426,19 @@ def abstract( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - assert softmax_offset_aval.dtype == jnp.float32, f"Expected softmax_offset_aval.dtype=float32, but got {softmax_offset_aval.dtype}" + assert ( + softmax_offset_aval.dtype == jnp.float32 + ), f"Expected softmax_offset_aval.dtype=float32, but got {softmax_offset_aval.dtype}" if config.softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: - assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), f"Expected softmax_offset_aval.shape=(1, {attn_heads}, 1, 1) for {config.softmax_type}, but got {softmax_offset_aval.shape}" + assert softmax_offset_aval.shape == (1, attn_heads, 1, 1), ( + f"Expected softmax_offset_aval.shape=(1, {attn_heads}, 1, 1) for" + f" {config.softmax_type}, but got {softmax_offset_aval.shape}" + ) else: - assert softmax_offset_aval.shape == (0,), f"Expected softmax_offset_aval.shape=(0,) for VANILLA_SOFTMAX, but got {softmax_offset_aval.shape}" + assert softmax_offset_aval.shape == (0,), ( + "Expected softmax_offset_aval.shape=(0,) for VANILLA_SOFTMAX, but got" + f" {softmax_offset_aval.shape}" + ) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @@ -533,7 +559,9 @@ def impl( _kv_segment_pos, config: _FusedAttnConfig, ): - assert FusedAttnFwdPrimitive.inner_primitive is not None, "FusedAttnFwdPrimitive.inner_primitive has not been registered" + assert ( + FusedAttnFwdPrimitive.inner_primitive is not None + ), "FusedAttnFwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -627,7 +655,9 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnFwdPrimitive.outer_primitive is not None, "FusedAttnFwdPrimitive.outer_primitive has not been registered" + assert ( + 
FusedAttnFwdPrimitive.outer_primitive is not None + ), "FusedAttnFwdPrimitive.outer_primitive has not been registered" q_bdim, _, _, _, _, seed_bdim, *_ = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim @@ -778,8 +808,15 @@ def abstract( v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype, f"Mismatched dtypes: q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}, doutput_dtype={doutput_dtype}" - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, f"Mismatched seqlen dtypes: q_seqlen_or_cu_seqlen_aval.dtype={q_seqlen_or_cu_seqlen_aval.dtype}, kv_seqlen_or_cu_seqlen_aval.dtype={kv_seqlen_or_cu_seqlen_aval.dtype}" + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype, ( + f"Mismatched dtypes: q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}," + f" bias_dtype={bias_dtype}, doutput_dtype={doutput_dtype}" + ) + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, ( + "Mismatched seqlen dtypes:" + f" q_seqlen_or_cu_seqlen_aval.dtype={q_seqlen_or_cu_seqlen_aval.dtype}," + f" kv_seqlen_or_cu_seqlen_aval.dtype={kv_seqlen_or_cu_seqlen_aval.dtype}" + ) ( batch_shape, @@ -983,7 +1020,9 @@ def impl( _kv_segment_pos, config, ): - assert FusedAttnBwdPrimitive.inner_primitive is not None, "FusedAttnBwdPrimitive.inner_primitive has not been registered" + assert ( + FusedAttnBwdPrimitive.inner_primitive is not None + ), "FusedAttnBwdPrimitive.inner_primitive has not been registered" sequence_descriptor = SequenceDescriptor( seqlens=(q_seqlen, kv_seqlen), @@ -1023,7 +1062,9 @@ def convert_to_2d(offsets, batch, max_seqlen): batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( q, k, v, config.qkv_layout ) - assert len(batch) == 1, f"Expected len(batch) == 1, but got 
len(batch)={len(batch)}, batch={batch}" + assert ( + len(batch) == 1 + ), f"Expected len(batch) == 1, but got len(batch)={len(batch)}, batch={batch}" kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 @@ -1082,7 +1123,9 @@ def convert_to_2d(offsets, batch, max_seqlen): @staticmethod def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None, "FusedAttnBwdPrimitive.outer_primitive has not been registered" + assert ( + FusedAttnBwdPrimitive.outer_primitive is not None + ), "FusedAttnBwdPrimitive.outer_primitive has not been registered" q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim, *_ = batch_dims out_bdims = q_bdim, k_bdim, v_bdim, bias_bdim, softmax_offset_bdim @@ -3396,7 +3439,9 @@ def fused_attn_fwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None, f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: @@ -3414,10 +3459,16 @@ def fused_attn_fwd( softmax_offset, (None, HEAD_AXES, None, None) ) else: - assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX, f"Expected VANILLA_SOFTMAX when softmax_offset is None and not OFF_BY_ONE_SOFTMAX, but got softmax_type={softmax_type}" + assert softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX, ( + "Expected VANILLA_SOFTMAX when softmax_offset is None and not OFF_BY_ONE_SOFTMAX," + f" but got softmax_type={softmax_type}" + ) softmax_offset = jnp.zeros(0, dtype=jnp.float32) else: - assert softmax_offset.dtype == jnp.float32, f"Expected softmax_offset.dtype=float32, but got softmax_offset.dtype={softmax_offset.dtype}" + assert softmax_offset.dtype == jnp.float32, ( + "Expected softmax_offset.dtype=float32, but got" + f" softmax_offset.dtype={softmax_offset.dtype}" + 
) # Shard by heads dimension if not VANILLA_SOFTMAX if softmax_type != AttnSoftmaxType.VANILLA_SOFTMAX: softmax_offset = with_sharding_constraint_by_logical_axes( @@ -3556,7 +3607,9 @@ def fused_attn_bwd( raise ValueError(f"Unknown {qkv_layout=}") if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None, f"bias must be None when attn_bias_type is NO_BIAS, but got bias with type={type(bias)}" + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias with type={type(bias)}" bias = jnp.zeros(0, dtype=qkv[0].dtype) if softmax_offset is None: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d6a36cf81f..d13092de55 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -166,8 +166,12 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ flatten_axis=flatten_axis, ) - assert not isinstance(lhs_q, ScaledTensor2x), f"Expected lhs_q to not be ScaledTensor2x after quantization, but got type={type(lhs_q)}" - assert not isinstance(rhs_q, ScaledTensor2x), f"Expected rhs_q to not be ScaledTensor2x after quantization, but got type={type(rhs_q)}" + assert not isinstance( + lhs_q, ScaledTensor2x + ), f"Expected lhs_q to not be ScaledTensor2x after quantization, but got type={type(lhs_q)}" + assert not isinstance( + rhs_q, ScaledTensor2x + ), f"Expected rhs_q to not be ScaledTensor2x after quantization, but got type={type(rhs_q)}" def has_rht_applied(q: AbstractBaseTensor) -> bool: return isinstance(q, ScaledTensor1x) and q.has_rht_applied @@ -529,8 +533,14 @@ def _dims_are_consecutive(dims): f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." 
) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - assert alpha.size == 1 and alpha.dtype == jnp.float32, f"Expected alpha to be a single float32 scalar, but got alpha.size={alpha.size}, alpha.dtype={alpha.dtype}" - assert beta.size == 1 and beta.dtype == jnp.float32, f"Expected beta to be a single float32 scalar, but got beta.size={beta.size}, beta.dtype={beta.dtype}" + assert alpha.size == 1 and alpha.dtype == jnp.float32, ( + f"Expected alpha to be a single float32 scalar, but got alpha.size={alpha.size}," + f" alpha.dtype={alpha.dtype}" + ) + assert beta.size == 1 and beta.dtype == jnp.float32, ( + f"Expected beta to be a single float32 scalar, but got beta.size={beta.size}," + f" beta.dtype={beta.dtype}" + ) # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() @@ -809,7 +819,9 @@ def batcher( is_outer, ): del transpose_batch_sequence, sequence_dim, is_outer - assert GemmPrimitive.outer_primitive is not None, "GemmPrimitive.outer_primitive has not been registered" + assert ( + GemmPrimitive.outer_primitive is not None + ), "GemmPrimitive.outer_primitive has not been registered" lhs_bdims, _, rhs_bdims, *_ = batch_dims # Batched GEMM is not supported @@ -1329,7 +1341,9 @@ def _te_gemm( alpha = jnp.ones((1,), jnp.float32) beta = jnp.zeros((1,), jnp.float32) if scaling_mode.is_nvfp4_scaling: - assert lhs_amax is not None and rhs_amax is not None, "NVFP4 scaling requires non-None amax for both LHS and RHS operands" + assert ( + lhs_amax is not None and rhs_amax is not None + ), "NVFP4 scaling requires non-None amax for both LHS and RHS operands" lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1410,7 +1424,9 @@ def impl( group_sizes, num_gemms, ): - assert GroupedGemmCopySizesPrimitive.inner_primitive is not None, "GroupedGemmCopySizesPrimitive.inner_primitive has not been 
registered" + assert ( + GroupedGemmCopySizesPrimitive.inner_primitive is not None + ), "GroupedGemmCopySizesPrimitive.inner_primitive has not been registered" out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( group_sizes, num_gemms=num_gemms, @@ -1563,7 +1579,9 @@ def impl( is_grouped_dense_wgrad, use_async_d2h_group_sizes, ): - assert GroupedGemmPrimitive.inner_primitive is not None, "GroupedGemmPrimitive.inner_primitive has not been registered" + assert ( + GroupedGemmPrimitive.inner_primitive is not None + ), "GroupedGemmPrimitive.inner_primitive has not been registered" (out, _) = GroupedGemmPrimitive.inner_primitive.bind( lhs_data, lhs_scale_inv, @@ -1701,7 +1719,9 @@ def _jax_scaled_matmul( lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) if lhs.scaling_mode.is_nvfp4_scaling: - assert lhs.amax is not None and rhs.amax is not None, "NVFP4 scaling requires non-None amax for both LHS and RHS operands" + assert ( + lhs.amax is not None and rhs.amax is not None + ), "NVFP4 scaling requires non-None amax for both LHS and RHS operands" lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv @@ -1953,7 +1973,10 @@ def grouped_gemm( lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING elif isinstance(lhs, GroupedScaledTensor1x): - assert isinstance(rhs, GroupedScaledTensor1x), f"Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but got type={type(rhs)}" + assert isinstance(rhs, GroupedScaledTensor1x), ( + "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but got" + f" type={type(rhs)}" + ) out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape @@ -1961,7 +1984,10 @@ def grouped_gemm( rhs_data = rhs.data lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv - assert 
lhs.scaling_mode == rhs.scaling_mode, f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}, rhs.scaling_mode={rhs.scaling_mode}" + assert lhs.scaling_mode == rhs.scaling_mode, ( + f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," + f" rhs.scaling_mode={rhs.scaling_mode}" + ) scaling_mode = lhs.scaling_mode else: raise TypeError("Unsupported lhs type object!") @@ -1998,8 +2024,13 @@ def grouped_gemm( and not isinstance(rhs, ScaledTensor) and quantizer_set != noop_quantizer_set ): - assert isinstance(quantizer_set.x, GroupedQuantizer), f"Expected quantizer_set.x to be GroupedQuantizer, but got type={type(quantizer_set.x)}" - assert type(quantizer_set.x) is type(quantizer_set.kernel), f"Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" + assert isinstance( + quantizer_set.x, GroupedQuantizer + ), f"Expected quantizer_set.x to be GroupedQuantizer, but got type={type(quantizer_set.x)}" + assert type(quantizer_set.x) is type(quantizer_set.kernel), ( + "Expected quantizer_set.x and quantizer_set.kernel to have the same type, but got" + f" {type(quantizer_set.x)} and {type(quantizer_set.kernel)}" + ) scaling_mode = quantizer_set.x.scaling_mode if ( quantizer_set.x.scaling_mode.is_tensor_scaling() @@ -2065,19 +2096,30 @@ def grouped_gemm( # Calling GroupedGEMM Custom Call K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs, f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + assert K_lhs == K_rhs, ( + f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from" + f" lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + ) M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G if 
is_grouped_dense_wgrad: N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) else: - assert group_sizes.size == rhs_shape[0], f"Expected group_sizes.size == rhs_shape[0], but got group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" + assert group_sizes.size == rhs_shape[0], ( + "Expected group_sizes.size == rhs_shape[0], but got" + f" group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" + ) - assert group_offset.size == 1, f"Expected group_offset.size == 1, but got group_offset.size={group_offset.size}" + assert ( + group_offset.size == 1 + ), f"Expected group_offset.size == 1, but got group_offset.size={group_offset.size}" has_bias = bias is not None - assert not has_bias or bias.shape == (group_sizes.size, N), f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" + assert not has_bias or bias.shape == ( + group_sizes.size, + N, + ), f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" bias = jnp.empty((), jnp.float32) if bias is None else bias (out,) = GroupedGemmPrimitive.outer_primitive.bind( diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 47e0f6166e..29292f946b 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -132,9 +132,17 @@ def abstract( ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16], f"Unsupported x_dtype={x_dtype}, expected one of [float32, float16, bfloat16]" - assert scale_aval is None or scale_aval.dtype == jnp.float32, f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" - assert amax_aval is None or amax_aval.dtype == jnp.float32, f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" + assert x_dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + ], f"Unsupported 
x_dtype={x_dtype}, expected one of [float32, float16, bfloat16]" + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert ( + amax_aval is None or amax_aval.dtype == jnp.float32 + ), f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -159,7 +167,10 @@ def abstract( mu_rsigama_dtype = jnp.float32 if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.size == beta_aval.size, f"Expected gamma_aval.size == beta_aval.size, but got gamma_aval.size={gamma_aval.size}, beta_aval.size={beta_aval.size}" + assert gamma_aval.size == beta_aval.size, ( + "Expected gamma_aval.size == beta_aval.size, but got" + f" gamma_aval.size={gamma_aval.size}, beta_aval.size={beta_aval.size}" + ) assert gamma_aval.dtype == beta_aval.dtype, ( f"gamma and beta should have the same dtype, but got {gamma_aval.dtype} and " f"{beta_aval.dtype}" @@ -265,18 +276,35 @@ def lowering( del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in - assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16], f"Unsupported x_aval.dtype={x_aval.dtype}, expected one of [float32, float16, bfloat16]" - assert scale_aval is None or scale_aval.dtype == jnp.float32, f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" - assert amax_aval is None or amax_aval.dtype == jnp.float32, f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" + assert x_aval.dtype in [ + jnp.float32, + jnp.float16, + jnp.bfloat16, + ], f"Unsupported x_aval.dtype={x_aval.dtype}, expected one of [float32, float16, bfloat16]" + assert ( + scale_aval is None or scale_aval.dtype == jnp.float32 + ), f"Expected scale_aval.dtype=float32, but got scale_aval.dtype={scale_aval.dtype}" + assert ( + amax_aval is None 
or amax_aval.dtype == jnp.float32 + ), f"Expected amax_aval.dtype=float32, but got amax_aval.dtype={amax_aval.dtype}" g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape if norm_type == NVTE_Norm_Type.LayerNorm: - assert gamma_aval.dtype == beta_aval.dtype, f"Expected gamma and beta to have the same dtype, but got gamma_aval.dtype={gamma_aval.dtype}, beta_aval.dtype={beta_aval.dtype}" + assert gamma_aval.dtype == beta_aval.dtype, ( + "Expected gamma and beta to have the same dtype, but got" + f" gamma_aval.dtype={gamma_aval.dtype}, beta_aval.dtype={beta_aval.dtype}" + ) b_type = ir.RankedTensorType(beta.type) b_shape = b_type.shape - assert g_type == b_type, f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}, beta_type={b_type}" - assert g_shape == b_shape, f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}, beta_shape={b_shape}" + assert g_type == b_type, ( + f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}," + f" beta_type={b_type}" + ) + assert g_shape == b_shape, ( + f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}," + f" beta_shape={b_shape}" + ) sm_margin = get_forward_sm_margin() return ffi.ffi_lowering( @@ -321,7 +349,9 @@ def impl( to describe implementation """ del is_outer - assert NormFwdPrimitive.inner_primitive is not None, "NormFwdPrimitive.inner_primitive has not been registered" + assert ( + NormFwdPrimitive.inner_primitive is not None + ), "NormFwdPrimitive.inner_primitive has not been registered" ( out, colwise_out, @@ -391,7 +421,9 @@ def batcher( to describe batch rules for vmap """ check_valid_batch_dims(batch_dims) - assert NormFwdPrimitive.outer_primitive is not None, "NormFwdPrimitive.outer_primitive has not been registered" + assert ( + NormFwdPrimitive.outer_primitive is not None + ), "NormFwdPrimitive.outer_primitive has not been registered" x, scale, amax, gamma, beta = batched_args x_bdim, scale_bdim, _, 
_, _ = batch_dims @@ -706,13 +738,26 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, norm_type, zero_ w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype) rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype) - assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype, f"Expected dz_aval.dtype={w_dtype} (matching gamma dtype), but got dz_aval.dtype={dtypes.canonicalize_dtype(dz_aval.dtype)}" - assert dz_aval.shape == x_aval.shape, f"Expected dz_aval.shape == x_aval.shape, but got dz_aval.shape={dz_aval.shape}, x_aval.shape={x_aval.shape}" + assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype, ( + f"Expected dz_aval.dtype={w_dtype} (matching gamma dtype), but got" + f" dz_aval.dtype={dtypes.canonicalize_dtype(dz_aval.dtype)}" + ) + assert dz_aval.shape == x_aval.shape, ( + f"Expected dz_aval.shape == x_aval.shape, but got dz_aval.shape={dz_aval.shape}," + f" x_aval.shape={x_aval.shape}" + ) if norm_type == NVTE_Norm_Type.LayerNorm: mu_dtype = dtypes.canonicalize_dtype(mu_aval.dtype) - assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], f"Expected mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], but got mu_aval.shape={mu_aval.shape}, rsigma_aval.shape={rsigma_aval.shape}, x_aval.shape[:-1]={x_aval.shape[:-1]}" - assert mu_dtype == rsigma_dtype == jnp.float32, f"Expected mu_dtype == rsigma_dtype == float32, but got mu_dtype={mu_dtype}, rsigma_dtype={rsigma_dtype}" + assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], ( + "Expected mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1], but got" + f" mu_aval.shape={mu_aval.shape}, rsigma_aval.shape={rsigma_aval.shape}," + f" x_aval.shape[:-1]={x_aval.shape[:-1]}" + ) + assert mu_dtype == rsigma_dtype == jnp.float32, ( + f"Expected mu_dtype == rsigma_dtype == float32, but got mu_dtype={mu_dtype}," + f" rsigma_dtype={rsigma_dtype}" + ) dx_aval = dz_aval dgamma_aval = dbeta_aval = gamma_aval @@ -756,8 +801,14 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, 
*, norm_type, zero_centered_gamma): g_shape = g_type.shape b_type = ir.RankedTensorType(gamma.type) b_shape = b_type.shape - assert g_type == b_type, f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}, beta_type={b_type}" - assert g_shape == b_shape, f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}, beta_shape={b_shape}" + assert g_type == b_type, ( + f"Expected gamma and beta to have the same IR type, but got gamma_type={g_type}," + f" beta_type={b_type}" + ) + assert g_shape == b_shape, ( + f"Expected gamma and beta to have the same shape, but got gamma_shape={g_shape}," + f" beta_shape={b_shape}" + ) sm_margin = get_backward_sm_margin() return ffi.ffi_lowering(NormBwdPrimitive.name)( @@ -774,7 +825,9 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, norm_type, zero_centered_gamma): @staticmethod def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): - assert NormBwdPrimitive.inner_primitive is not None, "NormBwdPrimitive.inner_primitive has not been registered" + assert ( + NormBwdPrimitive.inner_primitive is not None + ), "NormBwdPrimitive.inner_primitive has not been registered" dx, dgamma, dbeta, _ = NormBwdPrimitive.inner_primitive.bind( dz, x, mu, rsigma, gamma, norm_type=norm_type, zero_centered_gamma=zero_centered_gamma ) @@ -783,7 +836,9 @@ def impl(dz, x, mu, rsigma, gamma, norm_type, zero_centered_gamma): @staticmethod def batcher(batched_args, batch_dims, *, norm_type, zero_centered_gamma): check_valid_batch_dims(batch_dims) - assert NormBwdPrimitive.outer_primitive is not None, "NormBwdPrimitive.outer_primitive has not been registered" + assert ( + NormBwdPrimitive.outer_primitive is not None + ), "NormBwdPrimitive.outer_primitive has not been registered" dz, x, mu, rsigma, gamma = batched_args _, x_bdim, _, _, gamma_bdim = batch_dims diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 3e1828c681..513677e4a1 100644 --- 
a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -182,9 +182,9 @@ def __call__( is_gqa = h_q != h_kv if is_gqa: - assert (h_q % h_kv == 0) and (h_q >= h_kv), ( - f"num_query_heads ({h_q}) must be divisible by and >= num_kv_heads ({h_kv})" - ) + assert (h_q % h_kv == 0) and ( + h_q >= h_kv + ), f"num_query_heads ({h_q}) must be divisible by and >= num_kv_heads ({h_kv})" group_size = h_q // h_kv grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1])) @@ -430,7 +430,9 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) - assert x.dtype == query.dtype, f"output dtype {x.dtype} does not match query dtype {query.dtype}" + assert ( + x.dtype == query.dtype + ), f"output dtype {x.dtype} does not match query dtype {query.dtype}" return x @@ -715,9 +717,13 @@ def __call__( del self.attn_bias_type, self.attn_mask_type, self.qkv_layout if attn_bias_type == AttnBiasType.NO_BIAS: - assert bias is None, f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" + assert ( + bias is None + ), f"bias must be None when attn_bias_type is NO_BIAS, but got bias={bias}" else: - assert bias is not None, f"bias must not be None when attn_bias_type is {attn_bias_type}" + assert ( + bias is not None + ), f"bias must not be None when attn_bias_type is {attn_bias_type}" bias = bias.astype(input_dtype) self._assert_dtypes(query, key, value, qkv_layout) @@ -825,15 +831,13 @@ def __call__( key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: - assert qkv_layout.is_separate(), ( - f"Expected separate qkv_layout, but got {qkv_layout}" - ) + assert ( + qkv_layout.is_separate() + ), f"Expected separate qkv_layout, but got {qkv_layout}" assert sequence_descriptor is None or isinstance( sequence_descriptor, (jnp.ndarray, np.ndarray) - ), ( - f"sequence_descriptor must be None or ndarray, but got 
{type(sequence_descriptor)}" - ) + ), f"sequence_descriptor must be None or ndarray, but got {type(sequence_descriptor)}" x = _UnfusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -1313,7 +1317,9 @@ def query_init(*args): return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) def qkv_init(key, shape, dtype): - assert len(shape) == 3, f"qkv_init expects 3D shape, but got {len(shape)}D shape {shape}" + assert ( + len(shape) == 3 + ), f"qkv_init expects 3D shape, but got {len(shape)}D shape {shape}" assert shape[-2] == 3, f"qkv_init expects shape[-2] == 3, but got shape={shape}" q_key, k_key, v_key = jax_random.split(key, num=3) @@ -1500,9 +1506,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): elif qkv_layout == QKVLayout.BSHD_BS2HD: key, value = jnp.split(kv_proj, [1], axis=-2) else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, ( - f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" - ) + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" # No changes to memory layout, should trigger bitcast only (Ideally no Perf impact) query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) @@ -1528,9 +1534,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) if decode: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, ( - f"decode mode requires QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" - ) + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"decode mode requires QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" is_initialized = self.has_variable("cache", "cached_key") cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) @@ -1598,9 +1604,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kv_proj = with_sharding_constraint_by_logical_axes(kv_proj, kv_sharding_constraint) dpa_args = [query, 
kv_proj, None] else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD, ( - f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" - ) + assert ( + qkv_layout == QKVLayout.BSHD_BSHD_BSHD + ), f"Expected QKVLayout.BSHD_BSHD_BSHD, but got {qkv_layout}" query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim)) key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim)) value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim)) @@ -2113,7 +2119,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): l = inputs.shape[sequence_dim] attn_bias = rel_emb(l, l, False) - assert inputs.ndim == 3, f"inputs must be 3D (batch, sequence, hidden), but got {inputs.ndim}D" + assert ( + inputs.ndim == 3 + ), f"inputs must be 3D (batch, sequence, hidden), but got {inputs.ndim}D" # Make name be the exactly same as T5X, since names would affect # RNGKey during init and apply. Myabe no need in the feature. @@ -2163,9 +2171,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): )(inputs, inputs, attention_mask, attn_bias, deterministic=deterministic, decode=decode) def hidden_dropout(x, deterministic): - assert isinstance(self.hidden_dropout_dims, Sequence), ( - f"hidden_dropout_dims must be a Sequence, but got {type(self.hidden_dropout_dims)}" - ) + assert isinstance( + self.hidden_dropout_dims, Sequence + ), f"hidden_dropout_dims must be a Sequence, but got {type(self.hidden_dropout_dims)}" x_shape_len = len(x.shape) for dims in self.hidden_dropout_dims: assert -x_shape_len <= dims < x_shape_len, ( @@ -2196,9 +2204,9 @@ def hidden_dropout(x, deterministic): )(x, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None, ( - "ln_out must not be None when apply_residual_connection_post_layernorm is True" - ) + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out x = x + residual @@ 
-2258,9 +2266,9 @@ def hidden_dropout(x, deterministic): y = hidden_dropout(y, deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None, ( - "ln_out must not be None when apply_residual_connection_post_layernorm is True" - ) + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out mlp_input = y + residual @@ -2305,9 +2313,9 @@ def hidden_dropout(x, deterministic): )(mlp_input, deterministic=deterministic) if self.apply_residual_connection_post_layernorm: - assert ln_out is not None, ( - "ln_out must not be None when apply_residual_connection_post_layernorm is True" - ) + assert ( + ln_out is not None + ), "ln_out must not be None when apply_residual_connection_post_layernorm is True" residual = ln_out z = with_sharding_constraint_by_logical_axes( diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 83a8544256..0f173a89e3 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -34,7 +34,7 @@ def canonicalize_norm_type(x): if canonicalized not in ["layernorm", "rmsnorm"]: raise ValueError( f"Unsupported normalization type '{x}' (canonicalized: '{canonicalized}'). " - f"Valid options are: 'layernorm', 'rmsnorm'." + "Valid options are: 'layernorm', 'rmsnorm'." ) return canonicalized diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index d78d2ab1f0..8aff43b98e 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -270,16 +270,16 @@ def fused_attn_fwd( attn_scale = 1.0 / math.sqrt(d) if attn_bias_type not in ["no_bias", "alibi"]: - assert attn_bias is not None, ( - f"attn_bias tensor cannot be None when attn_bias_type={attn_bias_type!r}." 
- ) + assert ( + attn_bias is not None + ), f"attn_bias tensor cannot be None when attn_bias_type={attn_bias_type!r}." assert attn_bias.dtype == q.dtype, ( - f"attn_bias tensor must have the same dtype as q and kv: " + "attn_bias tensor must have the same dtype as q and kv: " f"attn_bias.dtype={attn_bias.dtype} but q.dtype={q.dtype}." ) assert fused_attention_backend != FusedAttnBackend["No_Backend"], ( - f"Fused attention does not support this input combination:" + "Fused attention does not support this input combination:" f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," f" q.dtype={q.dtype}, backend={fused_attention_backend}." @@ -300,11 +300,11 @@ def fused_attn_fwd( ) // BACKEND_F16m512_FP8_THREADS_PER_CTA assert s_quantizer is not None, ( - f"s_quantizer is required for FP8 fused attention forward" + "s_quantizer is required for FP8 fused attention forward" f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." ) assert o_quantizer is not None, ( - f"o_quantizer is required for FP8 fused attention forward" + "o_quantizer is required for FP8 fused attention forward" f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." ) else: @@ -496,7 +496,7 @@ def fused_attn_bwd( attn_scale = 1.0 / math.sqrt(d) assert fused_attention_backend != FusedAttnBackend["No_Backend"], ( - f"Fused attention backward does not support this input combination:" + "Fused attention backward does not support this input combination:" f" qkv_layout={qkv_layout!r}, attn_bias_type={attn_bias_type!r}," f" attn_mask_type={attn_mask_type!r}, q.shape={list(q.shape)}," f" q.dtype={q.dtype}, backend={fused_attention_backend}." 
@@ -504,26 +504,26 @@ def fused_attn_bwd( if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: assert len(aux_ctx_tensors) >= 1, ( - f"aux_ctx_tensors must contain rng_state as its last element," + "aux_ctx_tensors must contain rng_state as its last element," f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" f" for backend={fused_attention_backend}." ) if fused_attention_backend == FusedAttnBackend["FP8"]: assert s_quantizer is not None, ( - f"s_quantizer is required for FP8 fused attention backward" + "s_quantizer is required for FP8 fused attention backward" f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." ) assert dp_quantizer is not None, ( - f"dp_quantizer is required for FP8 fused attention backward" + "dp_quantizer is required for FP8 fused attention backward" f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." ) assert dqkv_dtype is not None, ( - f"dqkv_dtype is required for FP8 fused attention backward" + "dqkv_dtype is required for FP8 fused attention backward" f" (backend={fused_attention_backend}, qkv_layout={qkv_layout!r})." ) assert len(aux_ctx_tensors) == 3, ( - f"aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," + "aux_ctx_tensors must be [M, ZInv, rng_state] for FP8 fused attention," f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" f" (backend={fused_attention_backend})." ) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 164124ae29..f3b0bc776a 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -428,7 +428,7 @@ def pop_tensor( # 4. 
the layer was offloaded assert self.state == "reload_started", ( - f"Expected state='reload_started' when popping an offloaded tensor, " + "Expected state='reload_started' when popping an offloaded tensor, " f"but got state='{self.state}' for tensor={tensor_or_tensor_id}" ) # wait for the tensor to be reloaded @@ -893,8 +893,8 @@ def synchronization_function(self, tensor): This function is used to catch the backward pass of the model. """ assert tensor.requires_grad is True, ( - f"Tensor passed to synchronization_function must require grad to " - f"register backward hooks, but got requires_grad=False for tensor " + "Tensor passed to synchronization_function must require grad to " + "register backward hooks, but got requires_grad=False for tensor " f"with shape={tensor.shape}, dtype={tensor.dtype}" ) assert self.current_layer is not None, ( diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e6248f2dcc..b06f6f5619 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -272,8 +272,7 @@ at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType } else if (size == 1) { return at::empty({static_cast(shape.data[0])}, at::CUDA(GetATenDType(type))); } - NVTE_ERROR("Unsupported tensor allocation: ndim=", size, - ", init_to_zeros=", init_to_zeros, + NVTE_ERROR("Unsupported tensor allocation: ndim=", size, ", init_to_zeros=", init_to_zeros, ". Only 1D and 2D tensors are supported."); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7a9ac63e4e..d27480521b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1488,8 +1488,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou }); } else { // raise error since it's not supported yet - NVTE_ERROR("Pre-RHT amax is not supported yet. 
" - "Use with_post_rht_amax=true instead."); + NVTE_ERROR( + "Pre-RHT amax is not supported yet. " + "Use with_post_rht_amax=true instead."); } } else { // Without RHT if (compute_amax) { diff --git a/transformer_engine/pytorch/custom_recipes/gemm.py b/transformer_engine/pytorch/custom_recipes/gemm.py index 955640ec15..0c603587fc 100644 --- a/transformer_engine/pytorch/custom_recipes/gemm.py +++ b/transformer_engine/pytorch/custom_recipes/gemm.py @@ -72,7 +72,9 @@ def custom_gemm( assert sx is not None, "FPROP GEMM: activation scale (A.scale) is None" assert qw is not None, "FPROP GEMM: quantized weight data (B.data) is None" assert sw is not None, "FPROP GEMM: weight scale (B.scale) is None" - assert A.original_shape is not None, "FPROP GEMM: A.original_shape is None, cannot determine output shape" + assert ( + A.original_shape is not None + ), "FPROP GEMM: A.original_shape is None, cannot determine output shape" # Call quantizer's qgemm method result = quantizer.qgemm( @@ -115,9 +117,13 @@ def custom_gemm( elif gemm_type == GEMMType.WGRAD: qdy_t, sdy_t = A.data_t, A.scale_t qx_t, sx_t = B.data_t, B.scale_t - assert qdy_t is not None, "WGRAD GEMM: transposed quantized gradient data (A.data_t) is None" + assert ( + qdy_t is not None + ), "WGRAD GEMM: transposed quantized gradient data (A.data_t) is None" assert sdy_t is not None, "WGRAD GEMM: transposed gradient scale (A.scale_t) is None" - assert qx_t is not None, "WGRAD GEMM: transposed quantized activation data (B.data_t) is None" + assert ( + qx_t is not None + ), "WGRAD GEMM: transposed quantized activation data (B.data_t) is None" assert sx_t is not None, "WGRAD GEMM: transposed activation scale (B.scale_t) is None" result = quantizer.qgemm( diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index 7e07355bcd..a3a847500f 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ 
b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -325,9 +325,9 @@ def size(self, *args, **kwargs): # pylint: disable=unused-argument the second dimension by half. This method returns the logical shape that users expect, not the internal packed storage shape. """ - assert self.original_shape is not None, ( - "NVFP4TensorRef.size() called but original_shape has not been set" - ) + assert ( + self.original_shape is not None + ), "NVFP4TensorRef.size() called but original_shape has not been set" return torch.Size(self.original_shape) @@ -448,9 +448,9 @@ def _quantize_blockwise_reference( eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.ndim == 2, ( - f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape {x.shape}" - ) + assert ( + x.ndim == 2 + ), f"_quantize_blockwise_reference expects a 2D tensor, got {x.ndim}D with shape {x.shape}" using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 m, n = x.shape # Compute vec_max based on the original x (before reshape) @@ -771,7 +771,8 @@ def is_data_t_transposed_in_memory(self) -> bool: TODO(etsykunov): Confirm docstring is correct. 
""" raise NotImplementedError( - "NVFP4QuantizerRef.is_data_t_transposed_in_memory is not implemented for FP4 quantization" + "NVFP4QuantizerRef.is_data_t_transposed_in_memory is not implemented for FP4" + " quantization" ) def qgemm( @@ -820,19 +821,19 @@ def qgemm( else: - assert qresult_x is not None, ( - "qresult_x is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" - ) - assert qresult_w is not None, ( - "qresult_w is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" - ) - - assert qresult_x.global_amax_row is not None, ( - "qresult_x.global_amax_row must be set for non-pow_2_scales NVFP4 GEMM" - ) - assert qresult_w.global_amax_col is not None, ( - "qresult_w.global_amax_col must be set for non-pow_2_scales NVFP4 GEMM" - ) + assert ( + qresult_x is not None + ), "qresult_x is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + assert ( + qresult_w is not None + ), "qresult_w is required for non-pow_2_scales NVFP4 GEMM (needed for global_amax)" + + assert ( + qresult_x.global_amax_row is not None + ), "qresult_x.global_amax_row must be set for non-pow_2_scales NVFP4 GEMM" + assert ( + qresult_w.global_amax_col is not None + ), "qresult_w.global_amax_col must be set for non-pow_2_scales NVFP4 GEMM" sx = sx.to(torch.float32) sw = sw.to(torch.float32) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 96241e7509..e2dde06ba4 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -155,7 +155,7 @@ def set_tensor_model_parallel_attributes( if hasattr(tensor, attribute): raise RuntimeError( f"Tensor already has attribute '{attribute}' set. Cannot set " - f"tensor model parallel attributes on a tensor that already has them." + "tensor model parallel attributes on a tensor that already has them." ) # Set the attributes. 
setattr(tensor, "tensor_model_parallel", is_parallel) @@ -932,7 +932,7 @@ def reduce_scatter_along_first_dim( dim_size = list(inp.size()) if dim_size[0] % world_size != 0: raise ValueError( - f"First dimension of the tensor should be divisible by tensor parallel size, " + "First dimension of the tensor should be divisible by tensor parallel size, " f"but got dim_size[0]={dim_size[0]} and world_size={world_size} " f"(remainder={dim_size[0] % world_size})." ) @@ -1002,7 +1002,7 @@ def _all_gather_fp8( if not isinstance(inp, Float8TensorStorage): if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): raise TypeError( - f"Expected quantizer to be Float8Quantizer or Float8CurrentScalingQuantizer " + "Expected quantizer to be Float8Quantizer or Float8CurrentScalingQuantizer " f"when input is not Float8TensorStorage, but got {type(quantizer).__name__}." ) # we cannot directly gather the transposed fp8 tensor @@ -1710,8 +1710,7 @@ def gather_along_first_dim( if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): if not isinstance(quantizer, MXFP8Quantizer): raise TypeError( - f"Expected MXFP8Quantizer for MXFP8 all-gather, " - f"but got {type(quantizer).__name__}." + f"Expected MXFP8Quantizer for MXFP8 all-gather, but got {type(quantizer).__name__}." ) return _all_gather_mxfp8( inp, @@ -1725,8 +1724,7 @@ def gather_along_first_dim( if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): if not isinstance(quantizer, NVFP4Quantizer): raise TypeError( - f"Expected NVFP4Quantizer for NVFP4 all-gather, " - f"but got {type(quantizer).__name__}." + f"Expected NVFP4Quantizer for NVFP4 all-gather, but got {type(quantizer).__name__}." 
) return _all_gather_nvfp4( inp, @@ -2013,7 +2011,7 @@ def _fsdp_gather_tensors( if fsdp_group is not None: if len(shapes) != len(tensors): raise ValueError( - f"Number of tensors and tensor shapes must be equal, " + "Number of tensors and tensor shapes must be equal, " f"but got {len(shapes)} shapes and {len(tensors)} tensors." ) for s, t in zip(shapes, tensors): @@ -2072,9 +2070,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: FSDP-wrapped root module that may contain FSDP-wrapped TE modules. """ if not isinstance(fsdp_root, FSDP): - raise TypeError( - f"Root module must be FSDP-wrapped, but got {type(fsdp_root).__name__}." - ) + raise TypeError(f"Root module must be FSDP-wrapped, but got {type(fsdp_root).__name__}.") # If the root module is a TE module, inject FSDP information into it if _is_te_module(fsdp_root.module): @@ -2088,7 +2084,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: if root_state is None: raise RuntimeError( f"Root module ({type(fsdp_root.module).__name__}) does not have a valid " - f"_FSDPState. Ensure the module is properly wrapped with FSDP." + "_FSDPState. Ensure the module is properly wrapped with FSDP." 
) fsdp_root.module.fast_setattr("fsdp_group", root_state.process_group) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index fb7b695854..824cbe3d07 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -147,11 +147,11 @@ def _make_graphed_callables( delay_wgrad_compute = False if _order is None: assert len(sample_args) == len(callables), ( - f"Expected sample_args to have the same length as callables, " + "Expected sample_args to have the same length as callables, " f"but got {len(sample_args)} sample_args for {len(callables)} callables" ) assert len(sample_kwargs) == len(callables), ( - f"Expected sample_kwargs to have the same length as callables, " + "Expected sample_kwargs to have the same length as callables, " f"but got {len(sample_kwargs)} sample_kwargs for {len(callables)} callables" ) else: @@ -178,7 +178,7 @@ def _make_graphed_callables( num_model_chunks = max(_order_without_wgrad) num_microbatches = len(_order_without_wgrad) // num_model_chunks // 2 assert num_model_chunks * num_microbatches * 2 == len(_order_without_wgrad), ( - f"Pipeline-parallel order dimension mismatch: " + "Pipeline-parallel order dimension mismatch: " f"num_model_chunks ({num_model_chunks}) * num_microbatches ({num_microbatches}) * 2 " f"= {num_model_chunks * num_microbatches * 2}, " f"but len(_order_without_wgrad) = {len(_order_without_wgrad)}" @@ -232,7 +232,7 @@ def _make_graphed_callables( _prefix_num_layers.append(_prefix_num_layers[-1] + num_layers) assert len(sample_kwargs) == len(sample_args), ( - f"Pipeline-parallel schedule requires sample_kwargs and sample_args to have " + "Pipeline-parallel schedule requires sample_kwargs and sample_args to have " f"the same length, but got {len(sample_kwargs)} sample_kwargs " f"for {len(sample_args)} sample_args" ) @@ -368,7 +368,7 @@ def _make_graphed_callables( else () ) assert len(per_callable_module_params) == len(flatten_sample_args), ( - 
f"Pipeline-parallel dimension mismatch: " + "Pipeline-parallel dimension mismatch: " f"per_callable_module_params has {len(per_callable_module_params)} entries, " f"but flatten_sample_args has {len(flatten_sample_args)} entries" ) @@ -819,9 +819,9 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Replay forward graph fwd_graph.replay() - assert isinstance(static_outputs, tuple), ( - f"Expected static_outputs to be a tuple, but got {type(static_outputs)}" - ) + assert isinstance( + static_outputs, tuple + ), f"Expected static_outputs to be a tuple, but got {type(static_outputs)}" return tuple(o.detach() if o is not None else o for o in static_outputs) @staticmethod @@ -831,7 +831,7 @@ def backward(ctx, *grads): # Replay backward graph assert len(grads) == len(static_grad_outputs), ( - f"Backward graph grad dimension mismatch: " + "Backward graph grad dimension mismatch: " f"received {len(grads)} grads, " f"but expected {len(static_grad_outputs)} static_grad_outputs" ) @@ -848,9 +848,9 @@ def backward(ctx, *grads): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) # Input args that didn't require grad expect a None gradient. 
- assert isinstance(static_grad_inputs, tuple), ( - f"Expected static_grad_inputs to be a tuple, but got {type(static_grad_inputs)}" - ) + assert isinstance( + static_grad_inputs, tuple + ), f"Expected static_grad_inputs to be a tuple, but got {type(static_grad_inputs)}" return (None,) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 03482b6ce7..6999a42fce 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -178,12 +178,11 @@ def initialize_ub( f"quantization_modes must be a list, got {type(quantization_modes).__name__}" ) invalid_modes = [ - mode for mode in quantization_modes - if not isinstance(mode, UserBufferQuantizationMode) + mode for mode in quantization_modes if not isinstance(mode, UserBufferQuantizationMode) ] if invalid_modes: raise TypeError( - f"quantization_modes must be a list of UserBufferQuantizationMode, " + "quantization_modes must be a list of UserBufferQuantizationMode, " f"got invalid entries: {invalid_modes}" ) @@ -215,9 +214,7 @@ def initialize_ub( # Bootstrapping with torch.distributed API, so check backend and construct # intra/inter-node process groups... if not torch.distributed.is_initialized(): - raise RuntimeError( - "torch.distributed must be initialized before using Userbuffers" - ) + raise RuntimeError("torch.distributed must be initialized before using Userbuffers") if bootstrap_backend is None: bootstrap_backend = "nccl" if torch.distributed.is_mpi_available(): @@ -228,7 +225,7 @@ def initialize_ub( if bootstrap_backend not in ["gloo", "mpi", "nccl"]: raise ValueError( f"Invalid torch.distributed backend '{bootstrap_backend}' for bootstrapping " - f"Userbuffers. Must be one of: 'gloo', 'mpi', 'nccl'" + "Userbuffers. 
Must be one of: 'gloo', 'mpi', 'nccl'" ) if not torch.distributed.is_backend_available(bootstrap_backend): raise RuntimeError( @@ -355,7 +352,7 @@ def add_ub( ) if quantization_mode != UserBufferQuantizationMode.FP8: raise ValueError( - f"Atomic GEMM overlap supported only for FP8 GEMM, " + "Atomic GEMM overlap supported only for FP8 GEMM, " f"got quantization_mode={quantization_mode}" ) if method in ("bulk", "external"): @@ -730,9 +727,9 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> if buffer_key in FP8GlobalStateManager.global_amax_buffer: if buffer_key not in FP8GlobalStateManager.global_amax_history_buffer: raise RuntimeError( - f"TE internal error during amax history change: " + "TE internal error during amax history change: " f"buffer_key '{buffer_key}' found in global_amax_buffer " - f"but missing from global_amax_history_buffer" + "but missing from global_amax_history_buffer" ) FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = self.fp8_meta[ meta_key diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 8eb3e5823a..cbdaaf84f2 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -43,13 +43,9 @@ def forward( # Device check if not inp.is_cuda: - raise ValueError( - f"inp must be a CUDA tensor, but got tensor on {inp.device}." - ) + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") if not index.is_cuda: - raise ValueError( - f"index must be a CUDA tensor, but got tensor on {index.device}." - ) + raise ValueError(f"index must be a CUDA tensor, but got tensor on {index.device}.") # Shape check if inp.size(0) != index.size(0): raise ValueError( @@ -130,9 +126,7 @@ def forward( # None probs check if probs is not None: if not probs.is_cuda: - raise ValueError( - f"probs must be a CUDA tensor, but got tensor on {probs.device}." 
- ) + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") if probs.dtype != torch.float32: warnings.warn( @@ -150,9 +144,7 @@ def forward( # Device check if not inp.is_cuda: - raise ValueError( - f"inp must be a CUDA tensor, but got tensor on {inp.device}." - ) + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") if not row_id_map.is_cuda: raise ValueError( f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." @@ -218,18 +210,14 @@ def forward( return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) if not inp.is_cuda: - raise ValueError( - f"inp must be a CUDA tensor, but got tensor on {inp.device}." - ) + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") if not routing_map.is_cuda: raise ValueError( f"routing_map must be a CUDA tensor, but got tensor on {routing_map.device}." ) if probs is not None: if not probs.is_cuda: - raise ValueError( - f"probs must be a CUDA tensor, but got tensor on {probs.device}." - ) + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") if pad_offsets is not None: if not pad_offsets.is_cuda: raise ValueError( @@ -410,15 +398,13 @@ def forward( if with_probs: if not merging_probs.is_cuda: raise ValueError( - f"merging_probs must be a CUDA tensor, but got tensor on " + "merging_probs must be a CUDA tensor, but got tensor on " f"{merging_probs.device}." ) # Device check if not inp.is_cuda: - raise ValueError( - f"inp must be a CUDA tensor, but got tensor on {inp.device}." - ) + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") if not row_id_map.is_cuda: raise ValueError( f"row_id_map must be a CUDA tensor, but got tensor on {row_id_map.device}." @@ -791,9 +777,7 @@ def forward( return inp, probs if not inp.is_cuda: - raise ValueError( - f"inp must be a CUDA tensor, but got tensor on {inp.device}." 
- ) + raise ValueError(f"inp must be a CUDA tensor, but got tensor on {inp.device}.") if not split_sizes.is_cuda: raise ValueError( f"split_sizes must be a CUDA tensor, but got tensor on {split_sizes.device}." @@ -804,9 +788,7 @@ def forward( ) if probs is not None: if not probs.is_cuda: - raise ValueError( - f"probs must be a CUDA tensor, but got tensor on {probs.device}." - ) + raise ValueError(f"probs must be a CUDA tensor, but got tensor on {probs.device}.") num_tokens, hidden_size = inp.shape num_splits = split_sizes.size(0) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index c38864162e..04056583e6 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -193,12 +193,10 @@ def _cast_master_weights_to_fp8_delayed_scaling( continue # If master weight is not None, start_offset must be a valid value. - assert start_offset is not None, ( - "start_offset must not be None when master_weight is provided" - ) - assert start_offset >= 0, ( - f"start_offset must be non-negative, got {start_offset}" - ) + assert ( + start_offset is not None + ), "start_offset must not be None when master_weight is provided" + assert start_offset >= 0, f"start_offset must be non-negative, got {start_offset}" end_offset = start_offset + master_weight.numel() assert end_offset <= model_weight.numel(), ( f"end_offset ({end_offset}) exceeds model_weight numel ({model_weight.numel()}), " @@ -288,15 +286,15 @@ def _cast_master_weights_to_fp8_current_scaling( # Make sure all the model weights have the same numerical options. 
quantizer = model_weight._get_quantizer() assert quantizer.dtype == fp8_dtype, ( - f"All model weights must have the same fp8 dtype, " + "All model weights must have the same fp8 dtype, " f"expected {fp8_dtype} but got {quantizer.dtype}" ) assert quantizer.force_pow_2_scales == force_pow_2_scales, ( - f"All model weights must have the same force_pow_2_scales, " + "All model weights must have the same force_pow_2_scales, " f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" ) assert quantizer.amax_epsilon == amax_epsilon, ( - f"All model weights must have the same amax_epsilon, " + "All model weights must have the same amax_epsilon, " f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" ) @@ -413,19 +411,19 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Make sure all the model weights have the same numerical options. quantizer = model_weight._get_quantizer() assert block_len == quantizer.block_len, ( - f"All model weights must have the same block_len, " + "All model weights must have the same block_len, " f"expected {block_len} but got {quantizer.block_len}" ) assert fp8_dtype == quantizer.dtype, ( - f"All model weights must have the same fp8 dtype, " + "All model weights must have the same fp8 dtype, " f"expected {fp8_dtype} but got {quantizer.dtype}" ) assert force_pow_2_scales == quantizer.force_pow_2_scales, ( - f"All model weights must have the same force_pow_2_scales, " + "All model weights must have the same force_pow_2_scales, " f"expected {force_pow_2_scales} but got {quantizer.force_pow_2_scales}" ) assert amax_epsilon == quantizer.amax_epsilon, ( - f"All model weights must have the same amax_epsilon, " + "All model weights must have the same amax_epsilon, " f"expected {amax_epsilon} but got {quantizer.amax_epsilon}" ) @@ -433,20 +431,18 @@ def _cast_master_weights_to_fp8_blockwise_scaling( amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape) scale = torch.empty(scale_shape, 
dtype=torch.float32, device=device) scale_inv = model_weight._rowwise_scale_inv - assert len(scale_shape) == 2, ( - f"scale_shape must be 2D, got {len(scale_shape)}D shape {scale_shape}" - ) - assert len(scale_inv.shape) == 2, ( - f"scale_inv must be 2D, got {len(scale_inv.shape)}D shape {scale_inv.shape}" - ) - assert scale_inv.shape[0] == scale_shape[0], ( - f"scale_inv dim 0 mismatch: scale_inv.shape={scale_inv.shape}, " - f"scale_shape={scale_shape}" - ) - assert scale_inv.shape[1] == scale_shape[1], ( - f"scale_inv dim 1 mismatch: scale_inv.shape={scale_inv.shape}, " - f"scale_shape={scale_shape}" - ) + assert ( + len(scale_shape) == 2 + ), f"scale_shape must be 2D, got {len(scale_shape)}D shape {scale_shape}" + assert ( + len(scale_inv.shape) == 2 + ), f"scale_inv must be 2D, got {len(scale_inv.shape)}D shape {scale_inv.shape}" + assert ( + scale_inv.shape[0] == scale_shape[0] + ), f"scale_inv dim 0 mismatch: scale_inv.shape={scale_inv.shape}, scale_shape={scale_shape}" + assert ( + scale_inv.shape[1] == scale_shape[1] + ), f"scale_inv dim 1 mismatch: scale_inv.shape={scale_inv.shape}, scale_shape={scale_shape}" amaxes.append(amax) scales.append(scale) @@ -455,7 +451,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( # Compute amax of the master weight and store it in packed_amaxes. 
if master_weight is not None: assert len(model_weight.shape) == 2, ( - f"model_weight must be 2D for blockwise scaling, " + "model_weight must be 2D for blockwise scaling, " f"got {len(model_weight.shape)}D shape {model_weight.shape}" ) h, w = model_weight.shape @@ -509,7 +505,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( if not use_fsdp_shard_model_weights: model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] assert len(model_weight.shape) == 2, ( - f"model_weight must be 2D for blockwise scaling partial cast, " + "model_weight must be 2D for blockwise scaling partial cast, " f"got {len(model_weight.shape)}D shape {model_weight.shape}" ) h, w = model_weight.shape @@ -544,15 +540,13 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( cu_colwise_amax_sizes = [0] for model_weight, _, _, _ in params: rowwise_shape = model_weight._rowwise_scale_inv.shape - assert len(rowwise_shape) == 2, ( - f"rowwise_scale_inv must be 2D, " - f"got {len(rowwise_shape)}D shape {rowwise_shape}" - ) + assert ( + len(rowwise_shape) == 2 + ), f"rowwise_scale_inv must be 2D, got {len(rowwise_shape)}D shape {rowwise_shape}" colwise_shape = model_weight._columnwise_scale_inv.shape - assert len(colwise_shape) == 2, ( - f"columnwise_scale_inv must be 2D, " - f"got {len(colwise_shape)}D shape {colwise_shape}" - ) + assert ( + len(colwise_shape) == 2 + ), f"columnwise_scale_inv must be 2D, got {len(colwise_shape)}D shape {colwise_shape}" cu_rowwise_amax_sizes.append( cu_rowwise_amax_sizes[-1] + rowwise_shape[0] * rowwise_shape[1] ) @@ -592,7 +586,7 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( # Compute amax of the master weight and store it in packed_amaxes. 
if master_weight is not None: assert len(model_weight.shape) == 2, ( - f"model_weight must be 2D for MXFP8 scaling, " + "model_weight must be 2D for MXFP8 scaling, " f"got {len(model_weight.shape)}D shape {model_weight.shape}" ) h, w = model_weight.shape @@ -639,7 +633,7 @@ def _cast_master_weights_to_fp8_mxfp8_scaling( rowwise_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset] colwise_fragment = model_weight._columnwise_data.reshape(-1)[start_offset:end_offset] assert len(model_weight.shape) == 2, ( - f"model_weight must be 2D for MXFP8 scaling partial cast, " + "model_weight must be 2D for MXFP8 scaling partial cast, " f"got {len(model_weight.shape)}D shape {model_weight.shape}" ) h, w = model_weight.shape diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 1407205adb..295095c01a 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -375,7 +375,7 @@ def __init__( if parallel_attention_mlp: if self.layer_type != "encoder": raise ValueError( - f"parallel_attention requires layer_type='encoder', " + "parallel_attention requires layer_type='encoder', " f"but got layer_type={self.layer_type!r}" ) if self.apply_residual_connection_post_layernorm: @@ -820,8 +820,8 @@ def forward( if self.sequence_parallel and self.seq_length is not None: if hidden_states.shape[0] != self.seq_length // self.tp_size: raise ValueError( - f"Sequence dimension must be split across TP group when using " - f"sequence parallel. Expected hidden_states.shape[0] to be " + "Sequence dimension must be split across TP group when using " + "sequence parallel. Expected hidden_states.shape[0] to be " f"{self.seq_length // self.tp_size} " f"(seq_length={self.seq_length} // tp_size={self.tp_size}), " f"but got {hidden_states.shape[0]}." 
@@ -830,24 +830,21 @@ def forward( if ( "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" ) and attention_mask is not None: - if not all( - attention_mask[i].dtype == torch.bool for i in range(len(attention_mask)) - ): + if not all(attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))): non_bool_dtypes = [ (i, attention_mask[i].dtype) for i in range(len(attention_mask)) if attention_mask[i].dtype != torch.bool ] raise TypeError( - f"Attention mask must be a boolean tensor or a list/tuple of boolean " + "Attention mask must be a boolean tensor or a list/tuple of boolean " f"tensors, but found non-bool dtypes at indices: {non_bool_dtypes}" ) if ( "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" ) and enc_dec_attn_mask is not None: if not all( - enc_dec_attn_mask[i].dtype == torch.bool - for i in range(len(enc_dec_attn_mask)) + enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ): non_bool_dtypes = [ (i, enc_dec_attn_mask[i].dtype) @@ -855,7 +852,7 @@ def forward( if enc_dec_attn_mask[i].dtype != torch.bool ] raise TypeError( - f"Encoder-decoder attention mask must be boolean tensor(s), " + "Encoder-decoder attention mask must be boolean tensor(s), " f"but found non-bool dtypes at indices: {non_bool_dtypes}" ) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9607bea96f..16d74bcd6d 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -376,9 +376,7 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: required for tensor/model and sequence parallel. 
""" if not callable(get_rng_tracker): - raise TypeError( - f"get_rng_tracker must be callable, got {type(get_rng_tracker).__name__}" - ) + raise TypeError(f"get_rng_tracker must be callable, got {type(get_rng_tracker).__name__}") rng_tracker = None try: @@ -387,9 +385,7 @@ def validate_rng_states_func(get_rng_tracker: Callable) -> None: raise RuntimeError("Cannot call get_rng_tracker function") from e for method_name in ("get_states", "set_states", "fork"): - if not hasattr(rng_tracker, method_name) or not callable( - getattr(rng_tracker, method_name) - ): + if not hasattr(rng_tracker, method_name) or not callable(getattr(rng_tracker, method_name)): raise TypeError( f"rng_tracker object ({type(rng_tracker).__name__}) does not have " f"a valid callable method '{method_name}'. " @@ -464,7 +460,7 @@ def assert_dim_for_all_gather( if with_all_gather: if not quantizer.is_quantizable(tensor): raise ValueError( - f"All-gather requires a quantizable tensor for quantizer" + "All-gather requires a quantizable tensor for quantizer" f" {quantizer.__class__.__name__}, but got tensor with" f" shape={list(tensor.shape)} and dtype={tensor.dtype}" ) @@ -767,9 +763,7 @@ def torch_dtype_to_np_typestr(self): ret = _torch_dtype_to_np_typestr_dict.get(self.dtype) if ret is None: supported = ", ".join(str(d) for d in _torch_dtype_to_np_typestr_dict) - raise TypeError( - f"Unsupported dtype: {self.dtype}. Supported dtypes: {supported}" - ) + raise TypeError(f"Unsupported dtype: {self.dtype}. Supported dtypes: {supported}") return ret