diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ae7dde82..f42ee6cdd3 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,8 +643,17 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + // The specialized rowwise cast-only kernel vectorizes full 128-element chunks. + // Shapes with a partial row tail (for example, N=48) must use the generic kernel, + // otherwise the last chunk reads/writes past the logical end of the row. + const bool is_full_rowwise_chunk = (cols % 128 == 0); + + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || + (scaling_type == ScalingType::BIDIMENSIONAL); + if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; @@ -664,10 +673,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } case ScalingType::BIDIMENSIONAL: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; diff --git a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh index 41e62ac319..9459f0273a 100644 --- a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // anonymous namespace -inline bool is_cast_only_enabled() { - static bool enabled = []() { - const char *env = std::getenv("ENABLE_CAST_ONLY"); - return env != nullptr && (env[0] == '1'); - }(); - return enabled; - - // // FIXME: when finish debugging, remove this - // const char* env = std::getenv("ENABLE_CAST_ONLY"); - // return env != nullptr && (env[0] == '1'); -} - template inline bool hasSpec() { return false; @@ -112,19 +100,19 @@ inline bool hasSpec() { // OType could be [fp8e5m2, fp8e4m3] template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template