diff --git a/test/inductor/test_amd_mfma_nonkdim_config.py b/test/inductor/test_amd_mfma_nonkdim_config.py
new file mode 100644
index 0000000000000..db758b51f9f21
--- /dev/null
+++ b/test/inductor/test_amd_mfma_nonkdim_config.py
@@ -0,0 +1,108 @@
+# Owner(s): ["module: inductor"]
+"""Tests for ``torch._inductor.config.rocm.mfma_nonkdim``.
+
+The config reads the env var ``TORCHINDUCTOR_MFMA_NONKDIM`` at module
+import time, so we test the helper functions in
+``torch._inductor.template_heuristics.triton`` by patching the config
+value directly (`torch._inductor.config.patch`) rather than reloading
+the module. We also exercise the env path with a subprocess for the
+import-time read.
+"""
+import json
+import os
+import subprocess
+import sys
+import unittest
+
+import torch
+import torch._inductor.config as inductor_config
+from torch._inductor.template_heuristics.triton import (
+    _amd_mm_nonkdim_autotune_choices,
+    _amd_mm_nonkdim_default,
+)
+from torch.testing._internal.common_utils import run_tests, TestCase
+
+
+class TestAmdMfmaNonkdimConfig(TestCase):
+    def test_unset_matches_upstream(self):
+        # Patch to None so we mimic the "env unset" case regardless of the
+        # current process environment.
+        with inductor_config.patch({"rocm.mfma_nonkdim": None}):
+            self.assertEqual(_amd_mm_nonkdim_default(), 16)
+            self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [0, 16])
+
+    def test_force_16_matches_upstream(self):
+        with inductor_config.patch({"rocm.mfma_nonkdim": 16}):
+            self.assertEqual(_amd_mm_nonkdim_default(), 16)
+            self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [16])
+
+    def test_force_32(self):
+        with inductor_config.patch({"rocm.mfma_nonkdim": 32}):
+            self.assertEqual(_amd_mm_nonkdim_default(), 32)
+            self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [32])
+
+    def test_auto_extends_list(self):
+        with inductor_config.patch({"rocm.mfma_nonkdim": "auto"}):
+            # default stays 16 in "auto" mode (matches upstream ROCmGemmConfig default)
+            self.assertEqual(_amd_mm_nonkdim_default(), 16)
+            self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [0, 16, 32])
+
+    def test_force_zero(self):
+        with inductor_config.patch({"rocm.mfma_nonkdim": 0}):
+            self.assertEqual(_amd_mm_nonkdim_default(), 0)
+            self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [0])
+
+    def _spawn_env_probe(self, env_value):
+        """Spawn a fresh python process with TORCHINDUCTOR_MFMA_NONKDIM set
+        to ``env_value`` and read back the parsed config value.
+
+        This validates the import-time env parsing in config.py (which the
+        in-process patch tests can't exercise because the value was set at
+        import).
+
+        Only the last stdout line is parsed so that anything else printed
+        while importing torch in the child does not break the JSON decode.
+        """
+        env = os.environ.copy()
+        env.pop("TORCHINDUCTOR_MFMA_NONKDIM", None)
+        if env_value is not None:
+            env["TORCHINDUCTOR_MFMA_NONKDIM"] = env_value
+        code = (
+            "import torch, json; "
+            "import torch._inductor.config as c; "
+            "v = c.rocm.mfma_nonkdim; "
+            "print(json.dumps({'type': type(v).__name__, 'value': v}))"
+        )
+        out = subprocess.check_output(
+            [sys.executable, "-c", code], env=env, text=True
+        )
+        return json.loads(out.strip().splitlines()[-1])
+
+    def test_env_parsing_in_fresh_process(self):
+        for env_value, expected in [
+            (None, None),
+            ("16", 16),
+            ("32", 32),
+            ("0", 0),
+            ("auto", "auto"),
+            ("AUTO", "auto"),
+            (" 32 ", 32),
+            ("notanint", None),
+            ("", None),
+            # Values outside {0, 16, 32, auto} must be ignored, not forced
+            # and not allowed to raise while torch is being imported.
+            ("-16", None),
+            ("--5", None),
+            ("64", None),
+        ]:
+            with self.subTest(env_value=env_value):
+                r = self._spawn_env_probe(env_value)
+                self.assertEqual(
+                    r["value"],
+                    expected,
+                    f"env=`{env_value}` got {r['value']} expected {expected}",
+                )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index 1e6753cb66d6d..5a37a51e173d7 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -2276,6 +2276,39 @@ class rocm:
     # Flag to print register and LDS usage during compilation
     print_kernel_resource_usage = False
 
+    # Optional override for matrix_instr_nonkdim in the AMD MM Triton template.
+    # None (default = env unset) leaves upstream behaviour unchanged: the
+    # autotune sweep tries matrix_instr_nonkdim in [0, 16] and
+    # ROCmGemmConfig defaults to 16.
+    #
+    # Set via the env var TORCHINDUCTOR_MFMA_NONKDIM. Recognised values
+    # (case-insensitive for "auto"; surrounding whitespace is ignored):
+    #   unset              upstream default (= None)
+    #   "0"/"16"/"32"      force a single value; autotune sweep collapses to
+    #                      [value] and ROCmGemmConfig defaults to that value
+    #   "auto"             extend the autotune sweep to [0, 16, 32]; the default
+    #                      ROCmGemmConfig.matrix_instr_nonkdim stays 16 and
+    #                      autotune picks
+    # Anything else (e.g. "-16", "--5", "64", "notanint") is treated as unset.
+    # Only the three exact integer tokens above ever reach int(), so a
+    # malformed value cannot raise while this module is being imported.
+    #
+    # mfma_32x32x*_bf16 is only emitted when the value 32 is included in
+    # the sweep (or forced). This is workload-specific tuning; "auto" is
+    # the safest opt-in for shapes where mfma_32 might win, "32" forces
+    # it on, and "16" is the conservative upstream behaviour. Ignored on
+    # non-ROCm backends.
+    mfma_nonkdim: int | str | None = (
+        "auto"
+        if os.environ.get("TORCHINDUCTOR_MFMA_NONKDIM", "").strip().lower() == "auto"
+        else (
+            int(os.environ["TORCHINDUCTOR_MFMA_NONKDIM"])
+            if os.environ.get("TORCHINDUCTOR_MFMA_NONKDIM", "").strip()
+            in ("0", "16", "32")
+            else None
+        )
+    )
+
     # Path to ROCm installation, if None, use env variable ROCM_HOME.
     # In fbcode see triton/fb/TARGETS for how ROCM_HOME gets set.
     rocm_home: str | None = None
diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py
index 89999fc2f3636..a69188c4d6668 100644
--- a/torch/_inductor/template_heuristics/triton.py
+++ b/torch/_inductor/template_heuristics/triton.py
@@ -168,13 +168,39 @@ class FlexDecodeConfig:
 
 
 # ROCm classes
+def _amd_mm_nonkdim_default() -> int:
+    """Default ``matrix_instr_nonkdim`` for ROCm MM Triton configs.
+
+    Reads ``config.rocm.mfma_nonkdim``; falls back to 16 (the upstream
+    literal) when None or "auto". See `torch._inductor.config.rocm`.
+    """
+    v = config.rocm.mfma_nonkdim
+    return v if isinstance(v, int) else 16
+
+
+def _amd_mm_nonkdim_autotune_choices() -> list[int]:
+    """``matrix_instr_nonkdim`` values the AMD MM autotune sweep should try.
+
+    Default [0, 16] (upstream); a forced int collapses to a single value;
+    "auto" extends to [0, 16, 32].
+    """
+    v = config.rocm.mfma_nonkdim
+    if isinstance(v, int):
+        return [v]
+    if v == "auto":
+        return [0, 16, 32]
+    return [0, 16]
+
+
 @dataclasses.dataclass
 class ROCmGemmConfig(GemmConfig):
     """
     ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs
     """
 
-    matrix_instr_nonkdim: int = 16
+    matrix_instr_nonkdim: int = dataclasses.field(
+        default_factory=_amd_mm_nonkdim_default
+    )
     waves_per_eu: int = 0
     kpack: int = 2
 
@@ -1545,7 +1571,7 @@ def __init__(self) -> None:
             for num_stages in [1, self.default_num_stages]
             for num_warps in [4, 8]
             for group_m in [4, 8, 16]
-            for matrix_instr_nonkdim in [0, 16]
+            for matrix_instr_nonkdim in _amd_mm_nonkdim_autotune_choices()
             for waves_per_eu in [0, 2]
             for kpack in [2]
         ]