Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions test/inductor/test_amd_mfma_nonkdim_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Owner(s): ["module: inductor"]
"""Tests for ``torch._inductor.config.rocm.mfma_nonkdim``.

The config reads the env var ``TORCHINDUCTOR_MFMA_NONKDIM`` at module
import time, so we test the helper functions in
``torch._inductor.template_heuristics.triton`` by patching the config
value directly (`torch._inductor.config.patch`) rather than reloading
the module. We also exercise the env path with a subprocess for the
import-time read.
"""
import os
import subprocess
import sys
import unittest

import torch
import torch._inductor.config as inductor_config
from torch._inductor.template_heuristics.triton import (
_amd_mm_nonkdim_autotune_choices,
_amd_mm_nonkdim_default,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestAmdMfmaNonkdimConfig(TestCase):
def test_unset_matches_upstream(self):
# Patch to None so we mimic the "env unset" case regardless of the
# current process environment.
with inductor_config.patch({"rocm.mfma_nonkdim": None}):
self.assertEqual(_amd_mm_nonkdim_default(), 16)
self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [0, 16])

def test_force_16_matches_upstream(self):
with inductor_config.patch({"rocm.mfma_nonkdim": 16}):
self.assertEqual(_amd_mm_nonkdim_default(), 16)
self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [16])

def test_force_32(self):
with inductor_config.patch({"rocm.mfma_nonkdim": 32}):
self.assertEqual(_amd_mm_nonkdim_default(), 32)
self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [32])

def test_auto_extends_list(self):
with inductor_config.patch({"rocm.mfma_nonkdim": "auto"}):
# default stays 16 in "auto" mode (matches upstream ROCmGemmConfig default)
self.assertEqual(_amd_mm_nonkdim_default(), 16)
self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [0, 16, 32])

def test_force_zero(self):
with inductor_config.patch({"rocm.mfma_nonkdim": 0}):
self.assertEqual(_amd_mm_nonkdim_default(), 0)
self.assertEqual(_amd_mm_nonkdim_autotune_choices(), [0])

def _spawn_env_probe(self, env_value):
"""Spawn a fresh python process with TORCHINDUCTOR_MFMA_NONKDIM set
to ``env_value`` and read back the parsed config value.

This validates the import-time env parsing in config.py (which the
in-process patch tests can't exercise because the value was set at
import).
"""
env = os.environ.copy()
env.pop("TORCHINDUCTOR_MFMA_NONKDIM", None)
if env_value is not None:
env["TORCHINDUCTOR_MFMA_NONKDIM"] = env_value
code = (
"import torch, json; "
"import torch._inductor.config as c; "
"v = c.rocm.mfma_nonkdim; "
"print(json.dumps({'type': type(v).__name__, 'value': v}))"
)
out = subprocess.check_output(
[sys.executable, "-c", code], env=env
).decode().strip()
import json
return json.loads(out)

def test_env_parsing_in_fresh_process(self):
for env_value, expected in [
(None, None),
("16", 16),
("32", 32),
("0", 0),
("auto", "auto"),
("AUTO", "auto"),
("notanint", None),
("", None),
]:
r = self._spawn_env_probe(env_value)
self.assertEqual(
r["value"], expected,
f"env=`{env_value}` got {r['value']} expected {expected}",
)


if __name__ == "__main__":
run_tests()
28 changes: 28 additions & 0 deletions torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2276,6 +2276,34 @@ class rocm:
# Flag to print register and LDS usage during compilation
print_kernel_resource_usage = False

# Optional override for matrix_instr_nonkdim in the AMD MM Triton template.
# None (default = env unset) leaves upstream behaviour unchanged: the
# autotune sweep tries matrix_instr_nonkdim in [0, 16] and
# ROCmGemmConfig defaults to 16.
#
# Set via the env var TORCHINDUCTOR_MFMA_NONKDIM. Recognised values:
# unset upstream default (= None)
# "0"/"16"/"32" force a single value; autotune sweep collapses to
# [value] and ROCmGemmConfig defaults to that value
# "auto" extend the autotune sweep to [0, 16, 32]; the default
# ROCmGemmConfig.matrix_instr_nonkdim stays 16 and
# autotune picks
#
# mfma_32x32x*_bf16 is only emitted when the value 32 is included in
# the sweep (or forced). This is workload-specific tuning; "auto" is
# the safest opt-in for shapes where mfma_32 might win, "32" forces
# it on, and "16" is the conservative upstream behaviour. Ignored on
# non-ROCm backends.
mfma_nonkdim: int | str | None = (
"auto"
if os.environ.get("TORCHINDUCTOR_MFMA_NONKDIM", "").lower() == "auto"
else (
int(os.environ["TORCHINDUCTOR_MFMA_NONKDIM"])
if os.environ.get("TORCHINDUCTOR_MFMA_NONKDIM", "").lstrip("-").isdigit()
else None
)
)

# Path to ROCm installation, if None, use env variable ROCM_HOME.
# In fbcode see triton/fb/TARGETS for how ROCM_HOME gets set.
rocm_home: str | None = None
Expand Down
30 changes: 28 additions & 2 deletions torch/_inductor/template_heuristics/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,39 @@ class FlexDecodeConfig:


# ROCm classes
def _amd_mm_nonkdim_default() -> int:
"""Default ``matrix_instr_nonkdim`` for ROCm MM Triton configs.

Reads ``config.rocm.mfma_nonkdim``; falls back to 16 (the upstream
literal) when None or "auto". See `torch._inductor.config.rocm`.
"""
v = config.rocm.mfma_nonkdim
return v if isinstance(v, int) else 16


def _amd_mm_nonkdim_autotune_choices() -> list[int]:
"""``matrix_instr_nonkdim`` values the AMD MM autotune sweep should try.

Default [0, 16] (upstream); a forced int collapses to a single value;
"auto" extends to [0, 16, 32].
"""
v = config.rocm.mfma_nonkdim
if isinstance(v, int):
return [v]
if v == "auto":
return [0, 16, 32]
return [0, 16]


@dataclasses.dataclass
class ROCmGemmConfig(GemmConfig):
"""
ROCm subclass for GEMMs, with AMD backend specific tuneable kernargs
"""

matrix_instr_nonkdim: int = 16
matrix_instr_nonkdim: int = dataclasses.field(
default_factory=_amd_mm_nonkdim_default
)
waves_per_eu: int = 0
kpack: int = 2

Expand Down Expand Up @@ -1545,7 +1571,7 @@ def __init__(self) -> None:
for num_stages in [1, self.default_num_stages]
for num_warps in [4, 8]
for group_m in [4, 8, 16]
for matrix_instr_nonkdim in [0, 16]
for matrix_instr_nonkdim in _amd_mm_nonkdim_autotune_choices()
for waves_per_eu in [0, 2]
for kpack in [2]
]
Expand Down