From 83025fc8c0fa21e1737cee6b20fd36010ee63d92 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 5 May 2026 18:01:03 +0000 Subject: [PATCH 1/5] Use fast unfused cast mxfp8 kernels by default Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 2 +- .../cast/mxfp8/specialized/quantize_mxfp8.cuh | 20 ++++--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ae7dde82..76161befbe 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -644,7 +644,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES) { + !WITH_GEMM_SWIZZLED_SCALES && (scaling_type != ScalingType::COLWISE)) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; diff --git a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh index 41e62ac319..9459f0273a 100644 --- a/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/specialized/quantize_mxfp8.cuh @@ -91,18 +91,6 @@ __device__ __forceinline__ e8m0_t to_e8m0(IType amax) { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // anonymous namespace -inline bool is_cast_only_enabled() { - static bool enabled = []() { - const char *env = std::getenv("ENABLE_CAST_ONLY"); - return env != nullptr && (env[0] == '1'); - }(); - return enabled; - - // // FIXME: when finish debugging, remove this - // const char* env = std::getenv("ENABLE_CAST_ONLY"); - // return env != nullptr && (env[0] == '1'); -} - template inline bool hasSpec() { return false; @@ -112,19 +100,19 @@ inline bool hasSpec() { // OType could be [fp8e5m2, fp8e4m3] template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template <> inline bool hasSpec() { - return is_cast_only_enabled(); + return true; } template From e926c9a6542ce0c8469cd73b4a99033221a8983b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Tue, 5 May 2026 18:07:57 +0000 Subject: [PATCH 2/5] Removed dead code Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 76161befbe..bdd063f4f8 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -664,10 +664,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } - case ScalingType::COLWISE: { - NVTE_WARN("Colwise scaling will fallback to original kernel."); - break; - } case ScalingType::BIDIMENSIONAL: { using traits = specialized::CastTraits; auto kernel = specialized::quantize_mxfp8_kernel_cast_only; From 9ae6664d85a4d076216ea658965afdd4bfbe18e2 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 6 May 2026 12:42:23 +0000 Subject: [PATCH 3/5] Use fast kernel for full 32-element chunks only Signed-off-by: Oleg Goncharov --- .../common/cast/mxfp8/quantize_mxfp8.cuh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index bdd063f4f8..91560e8698 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,8 +643,18 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, + // The specialized rowwise cast-only kernel vectorizes full 32-element chunks. + // Shapes with a partial row tail (for example, N=48) must use the generic kernel, + // otherwise the last chunk reads/writes past the logical end of the row. + const bool is_full_rowwise_chunk = + (cols % specialized::CastTraits::chunkElems == 0); + + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || + (scaling_type == ScalingType::BIDIMENSIONAL); + if (specialized::hasSpec() && - !WITH_GEMM_SWIZZLED_SCALES && (scaling_type != ScalingType::COLWISE)) { + !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { switch (scaling_type) { case ScalingType::ROWWISE: { using traits = specialized::CastTraits; From 717038527c6d2dd5b64ca450b606538f38fa8ef3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 May 2026 12:43:22 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 91560e8698..a0ac98a028 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -647,11 +647,11 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // Shapes with a partial row tail (for example, N=48) must use the generic kernel, // otherwise the last chunk reads/writes past the logical end of the row. const bool is_full_rowwise_chunk = - (cols % specialized::CastTraits::chunkElems == 0); + (cols % specialized::CastTraits::chunkElems == 0); - const bool scaling_type_has_specialized_support = - (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || - (scaling_type == ScalingType::BIDIMENSIONAL); + const bool scaling_type_has_specialized_support = + (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) || + (scaling_type == ScalingType::BIDIMENSIONAL); if (specialized::hasSpec() && !WITH_GEMM_SWIZZLED_SCALES && scaling_type_has_specialized_support) { From 9956d3a2f0310d94fcedb5f1c595f3f889798d04 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Thu, 7 May 2026 14:25:24 +0000 Subject: [PATCH 5/5] Fix Signed-off-by: Oleg Goncharov --- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index a0ac98a028..f42ee6cdd3 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -643,11 +643,10 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - // The specialized rowwise cast-only kernel vectorizes full 32-element chunks. + // The specialized rowwise cast-only kernel vectorizes full 128-element chunks. // Shapes with a partial row tail (for example, N=48) must use the generic kernel, // otherwise the last chunk reads/writes past the logical end of the row. - const bool is_full_rowwise_chunk = - (cols % specialized::CastTraits::chunkElems == 0); + const bool is_full_rowwise_chunk = (cols % 128 == 0); const bool scaling_type_has_specialized_support = (scaling_type == ScalingType::ROWWISE && is_full_rowwise_chunk) ||