From a53d6cfbc968a8f7bf258c85a642c7a2ad986580 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 20 Nov 2025 23:30:21 +0000 Subject: [PATCH 1/4] Fallback to tanh when dtype is FP64. (cherry picked from commit cea11b3831d848478f84484e6d00633bea5b18e6) --- torch/_inductor/codegen/triton.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d55fcf4df449d..40ae154dbe666 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1315,7 +1315,12 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - if torch.version.hip and get_triton_version() > (3, 2): + dtype = V.kernel.cse.varname_map.get(x).dtype + if ( + torch.version.hip + and get_triton_version() > (3, 4) + and dtype != torch.float64 + ): # On ROCm, use fast_tanhf depending on Triton version # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ return f"libdevice.fast_tanhf({x})" From ad780aa34842267e1447fbb28a2bb3936a6e5127 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 21 Nov 2025 19:12:45 +0000 Subject: [PATCH 2/4] Handle case where there is no dtype. (cherry picked from commit 215a05da37f63d723a3efb04feb1e78f79d56a80) --- torch/_inductor/codegen/triton.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 40ae154dbe666..9ac961d2d7277 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1315,7 +1315,11 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - dtype = V.kernel.cse.varname_map.get(x).dtype + cse_var = V.kernel.cse.varname_map.get(x) + if cse_var and hasattr(cse_var, "dtype"): + dtype = cse_var.dtype + else: + dtype = None if ( torch.version.hip and get_triton_version() > (3, 4) From 495a6dea2ca22230f08144e2f6d2d695fb3b9cdb Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 27 Nov 2025 00:27:08 +0000 Subject: [PATCH 3/4] Fix UT. (cherry picked from commit c90b7ee867acfcdb5090d80e0c55ceb7d847f680) --- test/inductor/test_op_dtype_prop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 6f7eec601666b..5e7e41e22b034 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -205,6 +205,9 @@ def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtyp triton_op_name_overrides = { "round": "nearbyint", } + # ROCm uses fast_tahnf for everything input types that are not float64 + if torch.version.hip and input_dtype != torch.float64: + triton_op_name_overrides["tanh"] = "fast_tanhf" override = triton_op_name_overrides.get(op_name) triton_op_name = override if override is not None else torch_op_name From 5a6d97ed21b12564251f03d71a008a390fc8b402 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 27 Nov 2025 00:52:09 +0000 Subject: [PATCH 4/4] Correct Triton backport version. --- torch/_inductor/codegen/triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 9ac961d2d7277..8e8ba71ddd7be 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1322,7 +1322,7 @@ def tanh(x): dtype = None if ( torch.version.hip - and get_triton_version() > (3, 4) + and get_triton_version() > (3, 2) and dtype != torch.float64 ): # On ROCm, use fast_tanhf depending on Triton version