diff --git a/test/inductor/test_op_dtype_prop.py b/test/inductor/test_op_dtype_prop.py index 6f7eec601666b..5e7e41e22b034 100644 --- a/test/inductor/test_op_dtype_prop.py +++ b/test/inductor/test_op_dtype_prop.py @@ -205,6 +205,9 @@ def test_dtype_aware_codegen(self, op_name: str, load_upcast_to_fp32, input_dtyp triton_op_name_overrides = { "round": "nearbyint", } + # ROCm uses fast_tanhf for all input types that are not float64 + if torch.version.hip and input_dtype != torch.float64: + triton_op_name_overrides["tanh"] = "fast_tanhf" override = triton_op_name_overrides.get(op_name) triton_op_name = override if override is not None else torch_op_name diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index d55fcf4df449d..8e8ba71ddd7be 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1315,7 +1315,16 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - if torch.version.hip and get_triton_version() > (3, 2): + cse_var = V.kernel.cse.varname_map.get(x) + if cse_var and hasattr(cse_var, "dtype"): + dtype = cse_var.dtype + else: + dtype = None + if ( + torch.version.hip + and get_triton_version() > (3, 2) + and dtype != torch.float64 + ): # On ROCm, use fast_tanhf depending on Triton version # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ return f"libdevice.fast_tanhf({x})"