diff --git a/.automation_scripts/parse_xml_results.py b/.automation_scripts/parse_xml_results.py
new file mode 100644
index 0000000000000..7db2e1ce9233c
--- /dev/null
+++ b/.automation_scripts/parse_xml_results.py
@@ -0,0 +1,178 @@
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
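A minimal, self-contained sketch of the extraction above (the report path is the hypothetical example from the comment, not a real artifact):

    from pathlib import Path

    # The job id is whatever follows the last "_" in the first path component.
    report = Path("unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml")
    print(int(report.parts[0].rpartition("_")[2]))  # 5596745227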
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report comes from the rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Check against None explicitly; an Element can evaluate as falsy even when present,
+ # so `if not skipped` does not work as expected here.
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
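For orientation, a hedged illustration of the message check above; the skipped message is hypothetical but mirrors what a rerun-disabled-tests report typically carries:

    # Either marker below causes the whole report to be skipped by the parser.
    message = "Test is enabled but --rerun-disabled-tests verification mode is set"
    print("--rerun-disabled-tests" in message or "num_red" in message)  # True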
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ work_flow_name: str
+) -> Dict[Tuple[str, ...], Dict[str, Any]]:
+ """Convert a test report XML file into a dict of JSON-serializable test cases."""
+ print(f"Parsing {tag}s for test report: {report}")
+
+ job_id = get_job_id(report)
+ print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str, ...], Dict[str, Any]] = {}
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder (used by rerun disabled tests for test_ops)
+ # records the skipped message many times (50 by default). Gathering stats over all of
+ # those duplicate entries makes this script far too slow, so such reports are skipped
+ # here; the proper fix belongs in how we invoke pytest-flakefinder. A zipped test report
+ # from a rerun-disabled-tests run is only a few MB, but it balloons to an XML file of
+ # anywhere from a dozen to a few hundred MB once extracted.
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["work_flow_name"] = work_flow_name
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for backend in BACKENDS_LIST:
+ if backend in report.parts:
+ case_name = case_name + "_" + backend
+ break
+ case["invoking_file"] = case_name
+ test_cases[(case["invoking_file"], case["classname"], case["name"], case["work_flow_name"])] = case
+ elif tag == 'testsuite':
+ case["work_flow_name"] = work_flow_name
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for backend in BACKENDS_LIST:
+ if backend in report.parts:
+ case_name = case_name + "_" + backend
+ break
+ case["invoking_file"] = case_name
+
+ test_cases[(case["invoking_file"], case["invoking_xml"], case["work_flow_name"])] = case
+
+ return test_cases
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ #     <testcase name="test_foo" classname="test_bar"></testcase>
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ #     <testcase>my_inner_text</testcase> my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ #     <testcase>
+ #       <foo>hello</foo>
+ #       <foo>world</foo>
+ #       <bar>another</bar>
+ #     </testcase>
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
\ No newline at end of file
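To make the conversion above concrete, a small sketch of what process_xml_element returns for a typical testcase element (it assumes the script above is importable as parse_xml_results; the XML snippet is hypothetical):

    import xml.etree.ElementTree as ET

    from parse_xml_results import process_xml_element

    elem = ET.fromstring(
        '<testcase name="test_foo" classname="TestBar" time="0.01">'
        '<skipped message="skip reason"/></testcase>'
    )
    # Attributes are promoted to dict keys, "time" becomes a float, and the
    # child element is kept as a nested dict.
    print(process_xml_element(elem))
    # {'name': 'test_foo', 'classname': 'TestBar', 'time': 0.01, 'skipped': {'message': 'skip reason'}}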
diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py
new file mode 100644
index 0000000000000..514afd19624c3
--- /dev/null
+++ b/.automation_scripts/run_pytorch_unit_tests.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python3
+
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import argparse
+import os
+import shutil
+import subprocess
+from subprocess import STDOUT, CalledProcessError
+
+from collections import namedtuple
+from datetime import datetime
+from pathlib import Path
+from parse_xml_results import (
+ parse_xml_report
+)
+from pprint import pprint
+from typing import Any, Dict, List
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+DEFAULT_CORE_TESTS = [
+ "test_nn",
+ "test_torch",
+ "test_cuda",
+ "test_ops",
+ "test_unary_ufuncs",
+ "test_autograd",
+ "inductor/test_torchinductor"
+]
+
+DISTRIBUTED_CORE_TESTS = [
+ "distributed/test_c10d_common",
+ "distributed/test_c10d_nccl",
+ "distributed/test_distributed_spawn"
+]
+
+CONSOLIDATED_LOG_FILE_NAME="pytorch_unit_tests.log"
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, workflow_name, path="."):
+ test_cases = {}
+ items_list = os.listdir(path)
+ for entry in items_list:
+ new_dir = path + '/' + entry + '/'
+ if os.path.isdir(new_dir):
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ test_cases.update(
+ parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ workflow_name
+ )
+ )
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if 'type' in type_message and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_test_file_running_time(test_suite):
+ if 'time' in test_suite:
+ return test_suite["time"]
+ return 0
+
+def get_test_running_time(test_case):
+ if 'time' in test_case:
+ return test_case["time"]
+ return ""
+
+def summarize_xml_files(path, workflow_name):
+ # statistics
+ TOTAL_TEST_NUM = 0
+ TOTAL_PASSED_NUM = 0
+ TOTAL_SKIPPED_NUM = 0
+ TOTAL_XFAIL_NUM = 0
+ TOTAL_FAILED_NUM = 0
+ TOTAL_ERROR_NUM = 0
+ TOTAL_EXECUTION_TIME = 0
+
+ # parse the xml files
+ test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path)
+ test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path)
+ test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])
+ # results dict
+ res = {}
+ res_item_list = [ "PASSED", "SKIPPED", "XFAILED", "FAILED", "ERROR" ]
+ test_file_items = set()
+ for k, v in test_suites.items():
+ file_name = k[0]
+ if file_name not in test_file_items:
+ test_file_items.add(file_name)
+ # initialization
+ for item in res_item_list:
+ temp_item = test_file_and_status(file_name, item)
+ res[temp_item] = {}
+ temp_item_statistics = test_file_and_status(file_name, "STATISTICS")
+ res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0}
+ test_running_time = get_test_file_running_time(v)
+ res[temp_item_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+ else:
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ test_running_time = get_test_file_running_time(v)
+ res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+
+ for k, v in test_cases.items():
+ file_name = k[0]
+ class_name = k[1]
+ test_name = k[2]
+ combined_name = file_name + "::" + class_name + "::" + test_name
+ test_status = get_test_status(v)
+ test_running_time = get_test_running_time(v)
+ test_message = get_test_message(v, test_status)
+ test_info_value = ""
+ test_tuple_key_status = test_file_and_status(file_name, test_status)
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ TOTAL_TEST_NUM += 1
+ res[test_tuple_key_statistics]["TOTAL"] += 1
+ if test_status == "PASSED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["PASSED"] += 1
+ TOTAL_PASSED_NUM += 1
+ elif test_status == "SKIPPED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["SKIPPED"] += 1
+ TOTAL_SKIPPED_NUM += 1
+ elif test_status == "XFAILED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["XFAILED"] += 1
+ TOTAL_XFAIL_NUM += 1
+ elif test_status == "FAILED":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["FAILED"] += 1
+ TOTAL_FAILED_NUM += 1
+ elif test_status == "ERROR":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["ERROR"] += 1
+ TOTAL_ERROR_NUM += 1
+
+ # generate statistics_dict
+ statistics_dict = {}
+ statistics_dict["TOTAL"] = TOTAL_TEST_NUM
+ statistics_dict["PASSED"] = TOTAL_PASSED_NUM
+ statistics_dict["SKIPPED"] = TOTAL_SKIPPED_NUM
+ statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM
+ statistics_dict["FAILED"] = TOTAL_FAILED_NUM
+ statistics_dict["ERROR"] = TOTAL_ERROR_NUM
+ statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME
+ aggregate_item = workflow_name + "_aggregate"
+ total_item = test_file_and_status(aggregate_item, "STATISTICS")
+ res[total_item] = statistics_dict
+
+ return res
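To keep the shape of the returned dict in view, a hedged sketch of the kinds of keys it holds for one hypothetical test file plus the per-workflow aggregate; summarize_xml_files() fills in the values from the parsed reports:

    from collections import namedtuple

    test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])

    example_keys = [
        test_file_and_status("test_nn", "PASSED"),      # maps "test_nn::TestNN::test_foo" -> running time
        test_file_and_status("test_nn", "FAILED"),      # maps "test_nn::TestNN::test_bar" -> failure message
        test_file_and_status("test_nn", "STATISTICS"),  # per-file counters, e.g. {"TOTAL": 3, "PASSED": 2, ...}
        test_file_and_status("default_aggregate", "STATISTICS"),
    ]
    for key in example_keys:
        print(key)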
+
+def run_command_and_capture_output(cmd):
+ try:
+ print(f"Running command '{cmd}'")
+ with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file:
+ print(f"========================================", file=output_file, flush=True)
+ print(f"[RUN_PYTORCH_UNIT_TESTS] Running command '{cmd}'", file=output_file, flush=True) # send to consolidated file as well
+ print(f"========================================", file=output_file, flush=True)
+ p = subprocess.run(cmd, shell=True, stdout=output_file, stderr=STDOUT, text=True)
+ except CalledProcessError as e:
+ print(f"ERROR: Cmd {cmd} failed with return code: {e.returncode}!")
+
+def run_entire_tests(workflow_name, test_shell_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_entire_tests/"
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_entire_tests/"
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_entire_tests/"
+ # use test.sh for test execution
+ run_command_and_capture_output(test_shell_path)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ entire_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+ return entire_results_dict
+
+def run_priority_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_priority_tests/"
+ # use run_test.py for test execution
+ default_priority_test_suites = " ".join(DEFAULT_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + default_priority_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_priority_tests/"
+ # use run_test.py for test execution
+ distributed_priority_test_suites = " ".join(DISTRIBUTED_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + distributed_priority_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ priority_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return priority_results_dict
+
+def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src, selected_list):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_selected_tests/"
+ # use run_test.py for test execution
+ default_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + default_selected_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_selected_tests/"
+ # use run_test.py for test execution
+ distributed_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + distributed_selected_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_selected_tests/"
+ inductor_selected_test_suites = ""
+ non_inductor_selected_test_suites = ""
+ for item in selected_list:
+ if "inductor/" in item:
+ inductor_selected_test_suites += item
+ inductor_selected_test_suites += " "
+ else:
+ non_inductor_selected_test_suites += item
+ non_inductor_selected_test_suites += " "
+ if inductor_selected_test_suites != "":
+ inductor_selected_test_suites = inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --include " + inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ if non_inductor_selected_test_suites != "":
+ non_inductor_selected_test_suites = non_inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --inductor --include " + non_inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ selected_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return selected_results_dict
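A small sketch of how the inductor branch above splits a selected list; the suite names are hypothetical, and the split is expressed with comprehensions rather than the string concatenation used above:

    selected_list = ["inductor/test_torchinductor", "test_ops", "inductor/test_kernel_benchmark"]

    # Suites under inductor/ are passed to run_test.py as-is; everything else is
    # run with the --inductor flag so it still exercises the inductor backend.
    inductor_suites = " ".join(item for item in selected_list if "inductor/" in item)
    other_suites = " ".join(item for item in selected_list if "inductor/" not in item)
    print(inductor_suites)  # inductor/test_torchinductor inductor/test_kernel_benchmark
    print(other_suites)     # test_ops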
+
+def run_test_and_summarize_results(
+ pytorch_root_dir: str,
+ priority_tests: bool,
+ test_config: List[str],
+ default_list: List[str],
+ distributed_list: List[str],
+ inductor_list: List[str],
+ skip_rerun: bool) -> Dict[str, Any]:
+
+ # copy current environment variables
+ _environ = dict(os.environ)
+
+ # modify path
+ test_shell_path = pytorch_root_dir + "/.ci/pytorch/test.sh"
+ test_run_test_path = pytorch_root_dir + "/test/run_test.py"
+ repo_test_log_folder_path = pytorch_root_dir + "/.automation_logs/"
+ test_reports_src = pytorch_root_dir + "/test/test-reports/"
+ run_test_python_file = pytorch_root_dir + "/test/run_test.py"
+
+ # change directory to pytorch root
+ os.chdir(pytorch_root_dir)
+
+ # all test results dict
+ res_all_tests_dict = {}
+
+ # patterns
+ search_text = "--reruns=2"
+ replace_text = "--reruns=0"
+
+ # create logs folder
+ if not os.path.exists(repo_test_log_folder_path):
+ os.mkdir(repo_test_log_folder_path)
+
+ # Set common environment variables for all scenarios
+ os.environ['CI'] = '1'
+ os.environ['PYTORCH_TEST_WITH_ROCM'] = '1'
+ os.environ['HSA_FORCE_FINE_GRAIN_PCIE'] = '1'
+ os.environ['PYTORCH_TESTING_DEVICE_ONLY_FOR'] = 'cuda'
+ os.environ['CONTINUE_THROUGH_ERROR'] = 'True'
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(search_text, replace_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ # Time stamp
+ current_datetime = datetime.now().strftime("%Y%m%d_%H-%M-%S")
+ print("Current date & time : ", current_datetime)
+ # used as the job ID for this run
+ str_current_datetime = str(current_datetime)
+ overall_logs_path_current_run = repo_test_log_folder_path + str_current_datetime + "/"
+ os.mkdir(overall_logs_path_current_run)
+
+ global CONSOLIDATED_LOG_FILE_PATH
+ CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
+
+ # Check multi gpu availability if distributed tests are enabled
+ if ("distributed" in test_config) or len(distributed_list) != 0:
+ check_num_gpus_for_distributed()
+
+ # Install test requirements
+ command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
+ run_command_and_capture_output(command)
+
+ # Run the full test suite for each requested workflow
+ if not priority_tests and not default_list and not distributed_list and not inductor_list:
+ # run the full test suites for the default, distributed, and inductor workflows using test.sh
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ # distributed test process
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ # inductor test process
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ else:
+ workflow_list = list(test_config)
+ if "default" in workflow_list:
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ if "distributed" in workflow_list:
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ if "inductor" in workflow_list:
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ # Run priority tests for each requested workflow
+ elif priority_tests and not default_list and not distributed_list and not inductor_list:
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ # distributed test process
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ # will not run inductor priority tests
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ else:
+ workflow_list = list(test_config)
+ if "default" in workflow_list:
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ if "distributed" in workflow_list:
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ if "inductor" in workflow_list:
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ # Run specified tests for each workflow
+ elif (default_list or distributed_list or inductor_list) and not test_config and not priority_tests:
+ if default_list:
+ default_workflow_list = list(default_list)
+ res_default_selected = run_selected_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src, default_workflow_list)
+ res_all_tests_dict["default"] = res_default_selected
+ if distributed_list:
+ distributed_workflow_list = list(distributed_list)
+ res_distributed_selected = run_selected_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src, distributed_workflow_list)
+ res_all_tests_dict["distributed"] = res_distributed_selected
+ if inductor_list:
+ inductor_workflow_list = list(inductor_list)
+ res_inductor_selected = run_selected_tests("inductor", test_run_test_path, overall_logs_path_current_run, test_reports_src, inductor_workflow_list)
+ res_all_tests_dict["inductor"] = res_inductor_selected
+ else:
+ raise Exception("Invalid test configurations!")
+
+ # restore environment variables
+ os.environ.clear()
+ os.environ.update(_environ)
+
+ # restore files
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(replace_text, search_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ return res_all_tests_dict
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run PyTorch unit tests and generate xml results summary', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--test_config', nargs='+', default=[], type=str, help="space-separated list of test workflows to be executed eg. 'default distributed'")
+ parser.add_argument('--priority_tests', action='store_true', help="run priority tests only")
+ parser.add_argument('--default_list', nargs='+', default=[], help="space-separated list of 'default' config test suites/files to be executed eg. 'test_weak test_dlpack'")
+ parser.add_argument('--distributed_list', nargs='+', default=[], help="space-separated list of 'distributed' config test suites/files to be executed eg. 'distributed/test_c10d_common distributed/test_c10d_nccl'")
+ parser.add_argument('--inductor_list', nargs='+', default=[], help="space-separated list of 'inductor' config test suites/files to be executed eg. 'inductor/test_torchinductor test_ops'")
+ parser.add_argument('--pytorch_root', default='.', type=str, help="PyTorch root directory")
+ parser.add_argument('--skip_rerun', action='store_true', help="skip rerun process")
+ parser.add_argument('--example_output', type=str, help="{'workflow_name': {\n"
+ " test_file_and_status(file_name='workflow_aggregate', status='STATISTICS'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='ERROR'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='FAILED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='PASSED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='SKIPPED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='STATISTICS'): {} \n"
+ "}}\n")
+ parser.add_argument('--example_usages', type=str, help="RUN ALL TESTS: python3 run_pytorch_unit_tests.py \n"
+ "RUN PRIORITY TESTS: python3 run_pytorch_unit_tests.py --test_config distributed --priority_test \n"
+ "RUN SELECTED TESTS: python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack --inductor_list inductor/test_torchinductor")
+ return parser.parse_args()
+
+def check_num_gpus_for_distributed():
+ p = subprocess.run("rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True)
+ num_gpus_visible = int(p.stdout)
+ assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests"
+
+def main():
+ args = parse_args()
+ all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun)
+ pprint(dict(all_tests_results))
+
+if __name__ == "__main__":
+ main()
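Besides the CLI entry point, the runner can be driven programmatically; a minimal sketch (the checkout path and suite names are hypothetical), equivalent to `python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack`:

    from run_pytorch_unit_tests import run_test_and_summarize_results

    results = run_test_and_summarize_results(
        pytorch_root_dir="/path/to/pytorch",   # hypothetical checkout location
        priority_tests=False,
        test_config=[],
        default_list=["test_weak", "test_dlpack"],
        distributed_list=[],
        inductor_list=[],
        skip_rerun=False,
    )
    print(sorted(results.keys()))  # ['default']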
diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 71d3ef714fbe1..a1e9df4725c57 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-bbb06c0334a6772b92d24bde54956e675c8c6604
+d704bc6e69c1a588c8edd3cbb67505d554ed65f6
diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh
index f48140952c3ac..8e714bcb6cd32 100755
--- a/.ci/docker/common/install_triton.sh
+++ b/.ci/docker/common/install_triton.sh
@@ -21,7 +21,7 @@ elif [ -n "${TRITON_CPU}" ]; then
TRITON_REPO="https://github.com/triton-lang/triton-cpu"
TRITON_TEXT_FILE="triton-cpu"
else
- TRITON_REPO="https://github.com/triton-lang/triton"
+ TRITON_REPO="https://github.com/ROCm/triton"
TRITON_TEXT_FILE="triton"
fi
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index 583136d7df2f1..d3da8e93c639a 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -10,7 +10,7 @@ boto3==1.35.42
#Pinned versions: 1.19.12, 1.16.34
#test that import:
-click
+click==8.3.0
#Description: Command Line Interface Creation Kit
#Pinned versions:
#test that import:
@@ -63,7 +63,7 @@ lark==0.12.0
#Pinned versions: 0.12.0
#test that import:
-librosa>=0.6.2 ; python_version < "3.11" and platform_machine != "s390x"
+librosa==0.11.0 ; python_version < "3.11" and platform_machine != "s390x"
librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x"
#Description: A python package for music and audio analysis
#Pinned versions: >=0.6.2
@@ -113,9 +113,8 @@ ninja==1.11.1.3
#test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py
numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x"
-numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x"
-numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x"
-numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
+numba==0.60.0 ; python_version == "3.9" and platform_machine != "s390x"
+numba==0.61.2 ; python_version > "3.9" and platform_machine != "s390x"
#Description: Just-In-Time Compiler for Numerical Functions
#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
@@ -134,12 +133,10 @@ numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
#test_nn.py, test_namedtensor.py, test_linalg.py, test_jit_cuda_fuser.py,
#test_jit.py, test_indexing.py, test_datapipe.py, test_dataloader.py,
#test_binary_ufuncs.py
-numpy==1.22.4; python_version == "3.9" or python_version == "3.10"
-numpy==1.26.2; python_version == "3.11" or python_version == "3.12"
-numpy==2.1.2; python_version >= "3.13"
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
-pandas==2.0.3; python_version < "3.13"
-pandas==2.2.3; python_version >= "3.13"
+pandas==2.2.3
#onnxruntime
#Description: scoring engine for Open Neural Network Exchange (ONNX) models
@@ -169,10 +166,11 @@ pillow==11.0.0
#Pinned versions: 10.3.0
#test that import:
-protobuf==5.29.4
-#Description: Google's data interchange format
-#Pinned versions: 5.29.4
-#test that import: test_tensorboard.py, test/onnx/*
+protobuf==3.20.2 ; python_version <= "3.12"
+protobuf==4.25.1 ; python_version == "3.13"
+#Description: Google's data interchange format
+#Pinned versions: 3.20.2, 4.25.1
+#test that import: test_tensorboard.py
psutil
#Description: information on running processes and system utilization
@@ -194,7 +192,7 @@ pytest-flakefinder==1.1.0
#Pinned versions: 1.1.0
#test that import:
-pytest-rerunfailures>=10.3
+pytest-rerunfailures==14.0
#Description: plugin for rerunning failure tests in pytest
#Pinned versions:
#test that import:
@@ -250,8 +248,8 @@ scikit-image==0.22.0 ; python_version >= "3.10"
#Pinned versions: 0.20.3
#test that import:
-scipy==1.10.1 ; python_version <= "3.11"
-scipy==1.14.1 ; python_version >= "3.12"
+scipy==1.13.1 ; python_version == "3.9"
+scipy==1.14.1 ; python_version > "3.9"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
#Pinned versions: 1.10.1
@@ -265,7 +263,7 @@ scipy==1.14.1 ; python_version >= "3.12"
#test that import:
# needed by torchgen utils
-typing-extensions>=4.10.0
+typing_extensions==4.15.0
#Description: type hints for python
#Pinned versions:
#test that import:
@@ -275,7 +273,7 @@ typing-extensions>=4.10.0
#Pinned versions:
#test that import:
-unittest-xml-reporting<=3.2.0,>=2.0.0
+unittest-xml-reporting==3.2.0
#Description: saves unit test results to xml
#Pinned versions:
#test that import:
@@ -286,7 +284,7 @@ lintrunner==0.12.7
#Pinned versions: 0.12.7
#test that import:
-redis>=4.0.0
+redis==6.4.0
#Description: redis database
#test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py)
@@ -310,8 +308,7 @@ z3-solver==4.15.1.0 ; platform_machine != "s390x"
#Pinned versions:
#test that import:
-tensorboard==2.13.0 ; python_version < "3.13"
-tensorboard==2.18.0 ; python_version >= "3.13"
+tensorboard==2.18.0
#Description: Also included in .ci/docker/requirements-docs.txt
#Pinned versions:
#test that import: test_tensorboard
@@ -323,7 +320,8 @@ pywavelets==1.7.0 ; python_version >= "3.12"
#Pinned versions: 1.4.1
#test that import:
-lxml==5.3.0
+lxml==5.3.0 ; python_version <= "3.12"
+lxml==6.0.0 ; python_version == "3.13"
#Description: This is a requirement of unittest-xml-reporting
# Python-3.9 binaries
@@ -335,8 +333,9 @@ sympy==1.13.3
#Pinned versions:
#test that import:
-onnx==1.18.0
-#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal
+onnx==1.16.1 ; python_version <= "3.12"
+onnx==1.18.0 ; python_version == "3.13"
+#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
#Pinned versions:
#test that import:
@@ -360,10 +359,10 @@ pwlf==2.2.1
#test that import: test_sac_estimator.py
# To build PyTorch itself
-pyyaml
-pyzstd
-setuptools>=70.1.0
-six
+PyYAML==6.0.3
+pyzstd==0.18.0
+setuptools==79.0.1
+six==1.17.0
scons==4.5.2 ; platform_machine == "aarch64"
diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh
index f12a3ac075175..aa82d36aa7ce6 100755
--- a/.circleci/scripts/binary_populate_env.sh
+++ b/.circleci/scripts/binary_populate_env.sh
@@ -5,7 +5,9 @@ export TZ=UTC
tagged_version() {
GIT_DIR="${workdir}/pytorch/.git"
GIT_DESCRIBE="git --git-dir ${GIT_DIR} describe --tags --match v[0-9]*.[0-9]*.[0-9]*"
- if [[ ! -d "${GIT_DIR}" ]]; then
+ if [[ -n "${CIRCLE_TAG:-}" ]]; then
+ echo "${CIRCLE_TAG}"
+ elif [[ ! -d "${GIT_DIR}" ]]; then
echo "Abort, abort! Git dir ${GIT_DIR} does not exists!"
kill $$
elif ${GIT_DESCRIBE} --exact >/dev/null; then
@@ -69,6 +71,8 @@ fi
export PYTORCH_BUILD_NUMBER=1
+# This part is done in the builder scripts, so the duplicate code below is commented out
+: <<'BLOCK_COMMENT'
# Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS
TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt)
TRITON_CONSTRAINT="platform_system == 'Linux'"
@@ -110,6 +114,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}"
fi
fi
+BLOCK_COMMENT
USE_GLOO_WITH_OPENSSL="ON"
if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then
diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py
index 11fa8404273d3..f2851e3317256 100644
--- a/.github/scripts/build_triton_wheel.py
+++ b/.github/scripts/build_triton_wheel.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
import os
+import re
import shutil
import sys
from pathlib import Path
@@ -50,6 +51,30 @@ def patch_init_py(
with open(path, "w") as f:
f.write(orig)
+def get_rocm_version() -> str:
+ rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
+ rocm_version = "0.0.0"
+ rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
+ if not os.path.isfile(rocm_version_h):
+ rocm_version_h = f"{rocm_path}/include/rocm_version.h"
+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
+ if os.path.isfile(rocm_version_h):
+ RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
+ RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
+ RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
+ major, minor, patch = 0, 0, 0
+ with open(rocm_version_h) as version_file:
+ for line in version_file:
+ match = RE_MAJOR.search(line)
+ if match:
+ major = int(match.group(1))
+ match = RE_MINOR.search(line)
+ if match:
+ minor = int(match.group(1))
+ match = RE_PATCH.search(line)
+ if match:
+ patch = int(match.group(1))
+ rocm_version = f"{major}.{minor}.{patch}"
+ return rocm_version
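A quick, self-contained illustration of the header parsing above, run against a hypothetical excerpt of rocm_version.h rather than a real install:

    import re

    header = "#define ROCM_VERSION_MAJOR 7\n#define ROCM_VERSION_MINOR 0\n#define ROCM_VERSION_PATCH 1\n"
    parts = {
        key: re.search(rf"#define\s+ROCM_VERSION_{key}\s+(\d+)", header).group(1)
        for key in ("MAJOR", "MINOR", "PATCH")
    }
    print(".".join(parts[k] for k in ("MAJOR", "MINOR", "PATCH")))  # 7.0.1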
def build_triton(
*,
@@ -64,14 +89,24 @@ def build_triton(
if "MAX_JOBS" not in env:
max_jobs = os.cpu_count() or 1
env["MAX_JOBS"] = str(max_jobs)
-
+ version_suffix = ""
+ if not release:
+ # Nightly binaries include the triton commit hash, e.g. 2.1.0+e6216047b8,
+ # while release builds should only include the version, e.g. 2.1.0
+ rocm_version = get_rocm_version()
+ version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
+ version += version_suffix
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
triton_repo = "https://github.com/openai/triton"
if device == "rocm":
- triton_pkg_name = "pytorch-triton-rocm"
+ triton_repo = "https://github.com/ROCm/triton"
+ rocm_version = get_rocm_version() # e.g., "7.0.1"
+ if tuple(map(int, rocm_version.split("."))) > (7, 0, 0):
+ triton_pkg_name = "triton"
+ else:
+ triton_pkg_name = "pytorch-triton-rocm"
elif device == "xpu":
triton_pkg_name = "pytorch-triton-xpu"
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
@@ -89,6 +124,7 @@ def build_triton(
# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
+ env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
if with_clang_ldd:
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ce7890f002d3b..91181735750d6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -873,7 +873,7 @@ cmake_dependent_option(
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform"
ON
- "USE_CUDA OR USE_ROCM;NOT MSVC"
+ "(USE_CUDA AND NOT MSVC) OR USE_ROCM"
OFF)
cmake_dependent_option(
@@ -908,7 +908,7 @@ cmake_dependent_option(
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
- if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
+ if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 6c095680733fe..b30d8336e8ec9 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -301,13 +301,14 @@ IF(USE_FBGEMM_GENAI)
# Add additional HIPCC compiler flags for performance
set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS
- -mllvm
- -amdgpu-coerce-illegal-types=1
-mllvm
-enable-post-misched=0
-mllvm
-greedy-reverse-local-assignment=1
-fhip-new-launch-api)
+ if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0")
+ list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
+ endif()
hip_add_library(
fbgemm_genai STATIC
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 4d48084b0ab89..7a8d02be530e3 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) {
}
bool Context::allowTF32CuDNN(const std::string& op) const {
- if (op.size() == 0){
+ if (op.empty()){
bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32";
bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32";
TORCH_CHECK(
@@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const {
static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG";
static constexpr const std::array cublas_deterministic_configs = {":4096:8", ":16:8"};
-#ifdef USE_ROCM
-static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
-#endif
bool Context::checkCuBLASConfigDeterministic() {
// If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config
@@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) {
}
bool Context::allowTF32CuBLAS() const {
-#ifdef USE_ROCM
- const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
- if (allow_tf32 != true) {
- return false;
- }
-#endif
bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32";
TORCH_CHECK(
@@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const {
}
void Context::setAllowTF32CuBLAS(bool b) {
-#ifdef USE_ROCM
- const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
- if (allow_tf32 != true) {
- C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
- << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
- return;
- }
-#endif
float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee");
}
@@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string&
std::string msg;
auto iterp = _fp32_precisions.find(backend);
TORCH_CHECK(iterp != _fp32_precisions.end());
- for (auto p : iterp->second) {
+ for (const auto& p : iterp->second) {
msg += p;
msg += " ";
}
diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index 4300217235b84..06bcc5d4f49b8 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -6,8 +6,6 @@
#include
#include
#include
-#include
-#include
#include
namespace c10 {
@@ -19,9 +17,6 @@ class OperatorHandle;
struct OperatorKernel;
class KernelFunction;
-class KernelToken;
-class SafeKernelFunction;
-
template
using has_symint = std::disjunction<
std::is_same,
@@ -95,12 +90,6 @@ class TORCH_API KernelFunction final {
BoxedKernel::BoxedKernelFunction_withDispatchKeys;
KernelFunction();
- ~KernelFunction();
-
- KernelFunction(const KernelFunction& other);
- KernelFunction& operator=(const KernelFunction& other);
-
- KernelFunction(KernelFunction&&) noexcept = default;
// Fast path for dispatch to allow not touching the boxed kernel in
// the common case where unboxed is available.
@@ -273,9 +262,6 @@ class TORCH_API KernelFunction final {
// For testing internal invariants only
bool _equalsBoxedAndUnboxed(const KernelFunction&) const;
- // Register a token to be invalidated when this KernelFunction is destroyed
- void registerToken(std::weak_ptr token) const;
-
private:
explicit KernelFunction(
std::unique_ptr functor,
@@ -290,50 +276,6 @@ class TORCH_API KernelFunction final {
BoxedKernel boxed_kernel_func_;
void* unboxed_kernel_func_;
void* sym_unboxed_kernel_func_;
- // List of tokens that need to be invalidated when this KernelFunction is
- // destroyed (lazy allocation to save memory when empty)
- mutable std::unique_ptr>> tokens_;
-};
-
-// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
-// destroyed
-class KernelToken {
- public:
- bool isValid() const;
- void invalidate();
-
- private:
- std::atomic invalid_{false};
-};
-
-class SafeKernelFunction {
- public:
- SafeKernelFunction(
- const KernelFunction* kernel,
- std::string debug,
- std::shared_ptr opHandle);
-
- // Safe callBoxed - checks token validity first
- void callBoxed(
- const OperatorHandle& opHandle,
- DispatchKeySet dispatchKeySet,
- Stack* stack) const;
-
- // Get debug information
- const std::string& debug() const {
- return debug_;
- }
-
- // Get the OpHandle that lives on this SafeKernelFunction
- const OperatorHandle& opHandle() const {
- return *opHandle_;
- }
-
- private:
- KernelFunction kernel_;
- std::shared_ptr token_;
- std::string debug_;
- std::shared_ptr opHandle_;
};
} // namespace c10
diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h
index 672309ec19a2c..a89a0e8952b6e 100644
--- a/aten/src/ATen/core/boxing/KernelFunction_impl.h
+++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h
@@ -24,36 +24,6 @@ inline KernelFunction::KernelFunction()
unboxed_kernel_func_(nullptr),
sym_unboxed_kernel_func_(nullptr) {}
-inline KernelFunction::~KernelFunction() {
- if (tokens_) {
- for (auto& weak_token : *tokens_) {
- if (auto token = weak_token.lock()) {
- token->invalidate();
- }
- }
- }
-}
-
-inline KernelFunction::KernelFunction(const KernelFunction& other)
- : boxed_kernel_func_(other.boxed_kernel_func_),
- unboxed_kernel_func_(other.unboxed_kernel_func_),
- sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) {
- // tokens_ is intentionally not copied as we only care about invalidating
- // tokens if the original KernelFunction is destroyed
-}
-
-inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) {
- if (this != &other) {
- boxed_kernel_func_ = other.boxed_kernel_func_;
- unboxed_kernel_func_ = other.unboxed_kernel_func_;
- sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_;
-
- // tokens_ is intentionally not copied as we only care about invalidating
- // tokens if the original KernelFunction is destroyed
- }
- return *this;
-}
-
inline KernelFunction::KernelFunction(
std::unique_ptr functor,
InternalBoxedKernelFunction* boxed_kernel_func,
@@ -187,14 +157,6 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
std::forward(args)...);
}
-inline void KernelFunction::registerToken(
- std::weak_ptr token) const {
- if (!tokens_) {
- tokens_ = std::make_unique>>();
- }
- tokens_->push_back(std::move(token));
-}
-
inline KernelFunction KernelFunction::makeFromBoxedKernel(
BoxedKernel boxed_fn) {
return KernelFunction(
@@ -355,38 +317,4 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
std::forward(lambda)));
}
-inline bool KernelToken::isValid() const {
- return !invalid_.load(std::memory_order_acquire);
-}
-
-inline void KernelToken::invalidate() {
- invalid_.store(true, std::memory_order_release);
-}
-
-inline SafeKernelFunction::SafeKernelFunction(
- const KernelFunction* kernel,
- std::string debug,
- std::shared_ptr opHandle)
- : kernel_(kernel ? *kernel : KernelFunction()),
- token_(std::make_shared()),
- debug_(std::move(debug)),
- opHandle_(std::move(opHandle)) {
- // Register the token with the original kernel so it gets invalidated when the
- // kernel is destroyed
- if (kernel) {
- kernel->registerToken(token_);
- }
-}
-
-inline void SafeKernelFunction::callBoxed(
- const OperatorHandle& opHandle,
- DispatchKeySet dispatchKeySet,
- Stack* stack) const {
- TORCH_CHECK(
- token_ && token_->isValid(),
- "SafeKernelFunction has been invalidated ",
- debug_);
- kernel_.callBoxed(opHandle, dispatchKeySet, stack);
-}
-
} // namespace c10
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 43eb0028c70fe..bc043df6a93e9 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -487,10 +487,6 @@ class TORCH_API OperatorHandle {
return operatorDef_->op.hasComputedKernelForDispatchKey(k);
}
- SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
- return operatorDef_->op.getComputedKernelForDispatchKey(k);
- }
-
std::string dumpComputedTable() const {
return operatorDef_->op.dumpComputedTable();
}
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index c172e9b9c6096..b4063fb720be0 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -315,42 +315,6 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
return nullptr;
}
-SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
- DispatchKey k) const {
- TORCH_CHECK(
- !isAliasDispatchKey(k),
- "Alias keys do not have runtime kernel registrations.");
- const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
- TORCH_CHECK(
- dispatchTable_[dispatch_ix].isValid(),
- "no kernel for ",
- k,
- " for ",
- name_);
-
- // Get the KernelFunction object from kernels_ to pass to SafeKernelFunction
-
- // The KernelFunction object in dispatchTable_ is a copy of the KernelFunction
- // in the AnnotatedKernel in kernels_. A KernelFunction is only truly
- // deregistered when the kernel is removed from kernels_. However, the
- // KernelFunction in dispatchTable_ might be removed before it is deregistered
- // (when a newer kernel is registered). Therefore, here we want to return a
- // SafeKernelFunction that is backed by the original KernelFunction in
- // kernels_, so that we only invalidate it when the kernel is deregistered.
- auto [annotatedKernel, _] =
- computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);
-
- // Use findSchemaOrThrow to get OpHandle for the OperatorEntry
- auto& dispatcher = c10::Dispatcher::singleton();
- auto opHandle = dispatcher.findSchemaOrThrow(
- name_.name.c_str(), name_.overload_name.c_str());
-
- return SafeKernelFunction(
- &annotatedKernel.kernel,
- annotatedKernel.debug,
- std::make_shared(opHandle));
-}
-
const std::vector& OperatorEntry::getTags() const {
#if defined C10_MOBILE
TORCH_CHECK(false, "tags are not saved for Mobile");
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index 59b54ce1d9d32..83200ff9c94ff 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -217,8 +217,6 @@ class TORCH_API OperatorEntry final {
const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
// Returns true if the "computed table" has an entry for a particular key.
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
- // Returns a KernelFunction corresponding to the kernel in dispatchTable
- SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector& getTags() const;
void setReportErrorCallback_(std::unique_ptr callback);
diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh
index 60e1a19c1aacf..a65db3f2df12a 100644
--- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh
+++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh
@@ -45,6 +45,24 @@ struct OffsetCalculator {
C10_HOST_DEVICE offset_type get(index_t linear_idx) const {
offset_type offsets;
+
+#if defined(USE_ROCM)
+ if ((dims > 0) && (dims <= 2)) {
+ auto divmod = sizes_[0].divmod(linear_idx);
+ #pragma unroll
+ for (int arg = 0; arg < NARGS; arg++)
+ offsets[arg] = divmod.mod * strides_[0][arg];
+ if (dims >= 2) {
+ divmod = sizes_[1].divmod(divmod.div);
+ #pragma unroll
+ for (int arg = 0; arg < NARGS; arg++)
+ offsets[arg] += divmod.mod * strides_[1][arg];
+ }
+ // [...]
+ return offsets;
+ }
+#endif
+
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] = 0;
diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index 12ad84a15b180..ee28c5c1693f4 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
dtypes[i] = iter.dtype(i);
}
auto offset_calc = ::make_offset_calculator(iter);
+#ifdef USE_ROCM
+ constexpr int grp_sz = 128;
+ launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
+ if (unrl) {
+ auto offsets0 = offset_calc.get(idx);
+ auto offsets1 = offset_calc.get(idx + grp_sz);
+ auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+ auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+ void* out0 = data[0] + offsets0[0];
+ void* out1 = data[0] + offsets1[0];
+ void* out2 = data[0] + offsets2[0];
+ void* out3 = data[0] + offsets3[0];
+ arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
+ arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
+ arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
+ arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
+ c10::cast_and_store(dtypes[0], out0, result0);
+ c10::cast_and_store(dtypes[0], out1, result1);
+ c10::cast_and_store(dtypes[0], out2, result2);
+ c10::cast_and_store(dtypes[0], out3, result3);
+ } else {
+ auto offsets = offset_calc.get(idx);
+ void* out = data[0] + offsets[0];
+ arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
+ c10::cast_and_store(dtypes[0], out, result);
+ }
+ });
+#else
launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
c10::cast_and_store(dtypes[0], out, result);
});
+#endif
}
}
diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 59b0426bab1f0..62a07e1e28c86 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
});
}
+#ifdef USE_ROCM
+void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+ gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
+ return static_cast<float>(value);
+ });
+}
+void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+ gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
+ return static_cast<float>(value);
+ });
+}
+#endif
+
void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType dtype = iter.dtype(0);
ScalarType other_dtype = iter.dtype(1);
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
} else {
float16_copy_kernel_cuda(iter);
}
- } else if (isBitsType(dtype)) {
+ }
+#ifdef USE_ROCM
+ else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
+ if (iter.dtype(1) == kBFloat16) {
+ bfloat16tofloat32_copy_kernel_cuda(iter);
+ } else {
+ float16tofloat32_copy_kernel_cuda(iter);
+ }
+ }
+#endif
+ else if (isBitsType(dtype)) {
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu
index 02feb55cb69d6..dacef18c79b68 100644
--- a/aten/src/ATen/native/cuda/Indexing.cu
+++ b/aten/src/ATen/native/cuda/Indexing.cu
@@ -59,7 +59,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() {
#ifdef USE_ROCM
#define SKIP_SORTED_INDICES 32
template
-__global__ void indexing_backward_kernel(
+__global__ void indexing_backward_kernel_many_indices(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
using opmath_t = at::opmath_type<scalar_t>;
@@ -254,7 +254,8 @@ __global__ void indexing_backward_kernel_stride_1(
}
}
}
-#else
+#endif
+
template
__global__ void indexing_backward_kernel(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -333,6 +334,7 @@ __global__ void indexing_backward_kernel(
}
}
+#ifndef USE_ROCM
template
__global__ void indexing_backward_kernel_stride_1(
const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
@@ -708,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y,
+ grid.z);
dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z);
size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t);
#define KERNEL_GRID new_grid
@@ -780,11 +785,43 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List= 200000)
+ AT_DISPATCH_V2(
+ expandedValue.scalar_type(),
+ "indexing_backward_many_indices",
+ AT_WRAP([&] {
+ indexing_backward_kernel_many_indices<<>>(
+ sorted_indices.const_data_ptr<int64_t>(),
+ orig_indices.const_data_ptr<int64_t>(),
+ expandedValue.const_data_ptr<scalar_t>(),
+ src_.mutable_data_ptr<scalar_t>(),
+ num_indices,
+ sliceSize,
+ strideBefore,
+ nElemBefore,
+ accumulate);
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
+ }),
+ AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+ // AT_EXPAND(AT_FLOAT8_TYPES),
+ // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True
+ // should not be supported here, then reenable AT_FLOAT8_DTYPES
+ kFloat8_e4m3fn,
+ kFloat8_e5m2,
+ kFloat8_e4m3fnuz,
+ kFloat8_e5m2fnuz,
+ kComplexHalf,
+ kHalf,
+ kBool,
+ kBFloat16);
+ else
+#endif
AT_DISPATCH_V2(
expandedValue.scalar_type(),
"indexing_backward",
AT_WRAP([&] {
- indexing_backward_kernel<<>>(
+ indexing_backward_kernel<<>>(
sorted_indices.const_data_ptr(),
orig_indices.const_data_ptr(),
expandedValue.const_data_ptr(),
diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh
index 088aa517aa23a..0e3fc88b569c0 100644
--- a/aten/src/ATen/native/cuda/Normalization.cuh
+++ b/aten/src/ATen/native/cuda/Normalization.cuh
@@ -23,7 +23,7 @@ namespace at::native {
// The maximum number of threads in a block
#if defined(USE_ROCM)
-constexpr int MAX_BLOCK_SIZE = 256;
+constexpr int MAX_BLOCK_SIZE = 1024;
#else
constexpr int MAX_BLOCK_SIZE = 512;
#endif
@@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u;
// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static int getNumThreads(int nElem) {
#if defined(USE_ROCM)
- int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE };
+ int threadSizes[5] = { 64, 128, 256, 512, MAX_BLOCK_SIZE };
#else
int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE };
#endif
@@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
// first the reductions each thread does separately
 scalar_t sum = static_cast<scalar_t>(0);
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+ constexpr int UNRL = 4; // load unroll factor
+ scalar_t tmp[UNRL];
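+ // Prefetch UNRL clamped loads first, then accumulate only the in-range elements below,
+ // so the independent memory requests are issued before any dependent adds.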
+ for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
+#pragma unroll
+ for (int u = 0; u < UNRL; u++)
+ tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
+#pragma unroll
+ for (int u = 0; u < UNRL; u++)
+ if (x+u*blockDim.x < tensor.size(2))
+ sum += tmp[u];
+ }
+#else
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
sum += op(batch, plane, x);
}
+#endif
}
__shared__ scalar_t shared[C10_WARP_SIZE];
SumReduceOp reduce_op;
@@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel(
stat_accscalar_t var_n = 0;
int n = 0;
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
+#if defined(USE_ROCM)
+ constexpr int UNRL = 4;
+ stat_accscalar_t v_[UNRL];
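+ // Same unrolling idea as reduce(): prefetch with a clamped index, then apply the guarded
+ // Welford update only for elements that are actually in range.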
+ for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
+ for (int u = 0; u < UNRL; u++)
+ v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
+ for (int u = 0; u < UNRL; u++) {
+ if (x+u*blockDim.x < input.size(2)) {
+ stat_accscalar_t d1 = v_[u] - avg;
+ n++;
+ avg += d1 / n;
+ var_n += d1 * (v_[u] - avg);
+ }
+ }
+ }
+#else
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
stat_accscalar_t v = input[batch][plane][x];
stat_accscalar_t d1 = v - avg;
@@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel(
avg += d1 / n;
var_n += d1 * (v - avg);
}
+#endif
}
// first warpSum to get one value per thread to
diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
index b891750891d58..b46bbaa6500b9 100644
--- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
@@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
}
}
+#ifdef USE_ROCM
+// Helper function to compute output pixel range that can contribute to input pixel
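+// For align_corners=false the forward map is src = (dst + 0.5) * scale - 0.5, and an output
+// pixel dst only touches inputs floor(src) and floor(src) + 1; solving src = input_pos -/+ 1
+// for dst gives the lo/hi bounds below (the align_corners=true case inverts src = dst * scale).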
+template <typename accscalar_t>
+__device__ __forceinline__ void compute_output_range(
+ int input_pos,
+ accscalar_t scale,
+ int output_size,
+ bool align_corners,
+ int& min_output,
+ int& max_output) {
+ accscalar_t lo, hi;
+ if (align_corners) {
+ lo = static_cast<accscalar_t>(input_pos - 1) / scale;
+ hi = static_cast<accscalar_t>(input_pos + 1) / scale;
+ } else {
+ lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
+ hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
+ }
+ min_output = max(0, static_cast<int>(std::ceil(lo)));
+ max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
+}
+#endif
+
// Backward (adjoint) operation 1 <- 2 (accumulates)
 template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(1024)
@@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
const bool align_corners,
scalar_t* __restrict__ idata,
const scalar_t* __restrict__ odata) {
- const size_t o_numel = nc * width2 * height2;
+ // The ROCm path below is driven by input elements (i_numel); o_numel is only needed in the
+ // non-ROCm path and is defined there.
const size_t i_numel = nc * width1 * height1;
+#ifdef USE_ROCM
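+ // ROCm path uses a gather formulation: each thread owns one *input* element and sums the
+ // contributions of every output pixel that interpolates from it, so no atomic adds are
+ // needed; the non-ROCm path below scatters from output elements instead.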
+ for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
+ index += blockDim.x * gridDim.x) {
+ // Decode input pixel coordinates
+ size_t index_temp = index;
+ const int w1 = index_temp % width1;
+ index_temp /= width1;
+ const int h1 = index_temp % height1;
+ const size_t nc_idx = index_temp / height1;
+
+ accscalar_t grad_sum = 0;
+
+ // Find range of output pixels that could interpolate from this input pixel
+ int h2_min, h2_max, w2_min, w2_max;
+ compute_output_range(h1, rheight, height2, align_corners, h2_min, h2_max);
+ compute_output_range(w1, rwidth, width2, align_corners, w2_min, w2_max);
+
+ // Iterate over potential output pixels
+ for (int h2 = h2_min; h2 <= h2_max; h2++) {
+ for (int w2 = w2_min; w2 <= w2_max; w2++) {
+ // Compute source coordinates for this output pixel
+ const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
+ rheight, h2, align_corners, /*cubic=*/false);
+ const int h1_base = (int)h1r;
+ const int h1p = (h1_base < height1 - 1) ? 1 : 0;
+ const accscalar_t h1lambda = h1r - h1_base;
+ const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
+
+ const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
+ rwidth, w2, align_corners, /*cubic=*/false);
+ const int w1_base = (int)w1r;
+ const int w1p = (w1_base < width1 - 1) ? 1 : 0;
+ const accscalar_t w1lambda = w1r - w1_base;
+ const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
+
+ // Check if our input pixel participates in this interpolation and accumulate all weights
+ // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
+ // to the same pixel, so we need to accumulate weights from all matching positions
+ accscalar_t weight = 0;
+
+ // Check all four interpolation positions and accumulate weights
+ if (h1 == h1_base && w1 == w1_base) {
+ weight += h0lambda * w0lambda; // top-left
+ }
+ if (h1 == h1_base && w1 == w1_base + w1p) {
+ weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
+ }
+ if (h1 == h1_base + h1p && w1 == w1_base) {
+ weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
+ }
+ if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
+ weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
+ }
+
+ if (weight > 0) {
+ const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
+ grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
+ }
+ }
+ }
+
+ // Write accumulated gradient (no atomics needed)
+ idata[index] = static_cast<scalar_t>(grad_sum);
+ }
+#else
+ const size_t o_numel = nc * width2 * height2;
for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
index += blockDim.x * gridDim.x) {
size_t index_temp = index;
@@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
 static_cast<scalar_t>(h1lambda * w1lambda * d2val),
true);
}
+#endif
}
template
@@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
// threads are not covering the whole input tensor.
grad_input.zero_();
- const size_t num_kernels = nbatch * channels * output_height * output_width;
const int num_threads = std::min(
at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
return;
}
+#ifdef USE_ROCM
+ constexpr bool use_input = true;
+#else
+ constexpr bool use_input = false;
+#endif
+
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
+ const size_t num_kernels = nbatch * channels * output_height * output_width;
+
 upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
 <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
input_height,
@@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
 const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
input_width, output_width, align_corners, scales_w);
+ const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
+
 upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
 <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
num_threads,
diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index 940680eb3682f..81387bcceaf01 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum(
if constexpr (!rms_norm){
U delta = val - curr_sum.mean;
U new_count = curr_sum.count + 1.f;
+#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
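+ // __builtin_amdgcn_rcpf is the AMD GPU hardware reciprocal approximation: faster than a
+ // full-precision divide at a small accuracy cost (opt out via PYTORCH_LAYERNORM_FAST_RECIPROCAL=0).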
+ U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
+#else
U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
+#endif
return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
} else{
return {0.f, curr_sum.sigma2 + val * val, 0};
@@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine(
U count = dataA.count + dataB.count;
U mean, sigma2;
if (count > decltype(dataB.count){0}) {
+#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+ auto coef = __builtin_amdgcn_rcpf(count);
+#else
auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
+#endif
auto nA = dataA.count * coef;
auto nB = dataB.count * coef;
mean = nA*dataA.mean + nB*dataB.mean;
diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu
index b8b43e0086c1a..c2193f2378dd5 100644
--- a/aten/src/ATen/native/transformers/cuda/attention.cu
+++ b/aten/src/ATen/native/transformers/cuda/attention.cu
@@ -95,6 +95,72 @@
#endif
#endif
+#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
+namespace pytorch_flash
+{
+std::tuple<
+ at::Tensor,
+ at::Tensor,
+ at::Tensor,
+ at::Tensor,
+ at::Tensor,
+ at::Tensor,
+ at::Tensor,
+ at::Tensor>
+mha_fwd(
+ const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
+ const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
+ const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
+ std::optional<at::Tensor>&
+ out_, // batch_size x seqlen_q x num_heads x head_size
+ std::optional<at::Tensor>&
+ alibi_slopes_, // num_heads or batch_size x num_heads
+ const float p_dropout,
+ const float softmax_scale,
+ bool is_causal,
+ std::optional<int64_t> window_size_left,
+ std::optional<int64_t> window_size_right,
+ const float softcap,
+ const bool return_softmax,
+ std::optional<at::Generator> gen_) {
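+ // Dispatch: use the composable_kernel (CK) flash-attention build when the user selected the
+ // CK backend at runtime; otherwise fall back to the AOTriton implementation (mha_fwd_aot).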
+#if defined(USE_ROCM_CK_SDPA)
+ if (at::globalContext().getROCmFAPreferredBackend() ==
+ at::ROCmFABackend::Ck) {
+ const int non_null_window_left = window_size_left.value_or(-1);
+ const int non_null_window_right = window_size_right.value_or(-1);
+ std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
+ return mha_fwd_ck(
+ q,
+ k,
+ v,
+ out_,
+ p_dropout,
+ softmax_scale,
+ is_causal,
+ non_null_window_left,
+ non_null_window_right,
+ return_softmax,
+ gen_,
+ dummy_attn_bias); // Not used in flash attention
+ }
+#endif
+ return mha_fwd_aot(
+ q,
+ k,
+ v,
+ out_,
+ alibi_slopes_,
+ p_dropout,
+ softmax_scale,
+ is_causal,
+ window_size_left,
+ window_size_right,
+ return_softmax,
+ gen_);
+}
+}
+#endif
+
namespace at {
namespace cuda::philox {
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 660aee3647cea..8eec0de7773f3 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) {
}
return false;
}
+ if constexpr(caller_is_meff) {
+ bool is_half = (params.query.dtype() == at::kHalf) ||
+ (params.query.dtype() == at::kBFloat16);
+ const int64_t alignment = is_half ? 8 : 4;
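+ // i.e. the last dim must be 16-byte aligned: 8 elements for fp16/bf16, 4 for fp32.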
+ if (!(query_size_last % alignment == 0 && query_size_last > 0 &&
+ value_size_last % alignment == 0 && value_size_last > 0)) {
+ if (debug) {
+ TORCH_WARN(
+ "Mem efficient attention requires last dimension of inputs to be divisible by ",
+ alignment,
+ ". ",
+ "Got Query.size(-1): ",
+ query_size_last,
+ ", Key.size(-1): ",
+ params.key.sym_size(-1),
+ ", Value.size(-1): ",
+ params.value.sym_size(-1),
+ " instead.");
+ }
+ return false;
+ }
+ }
return true;
}
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
index b5b1ed4292896..2467cb809fdbf 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip
@@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
using sdp::aotriton_adapter::mk_aotensor;
using sdp::aotriton_adapter::mk_aoscalartensor;
using sdp::aotriton_adapter::mk_philoxtensor;
+ using sdp::aotriton_adapter::mk_atomictensor;
using sdp::aotriton_adapter::cast_dtype;
at::Tensor atomic_counter;
if (is_causal) {
- atomic_counter = at::zeros({1}, q.options());
+ atomic_counter = at::zeros({1}, q.options().dtype(at::kInt));
}
aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype()));
auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
@@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot
auto nullscalar = mk_philoxtensor(nullptr);
auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar;
auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar;
- auto persistent_counter = is_causal ? mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar;
+ auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
if (uses_swa || AOTRITON_ALWAYS_V3_API) {
#if AOTRITON_V3_API
using aotriton::v3::flash::CausalType;
diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
index f6f2240d4f091..71a1959065970 100644
--- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
+++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h
@@ -270,7 +270,7 @@ std::tuple mha_varle
#endif
TORCH_API
-inline std::tuple<
+std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
@@ -294,42 +294,7 @@ mha_fwd(
std::optional window_size_right,
const float softcap,
const bool return_softmax,
- std::optional gen_) {
-#if defined(USE_ROCM_CK_SDPA)
- if (at::globalContext().getROCmFAPreferredBackend() ==
- at::ROCmFABackend::Ck) {
- const int non_null_window_left = window_size_left.value_or(-1);
- const int non_null_window_right = window_size_right.value_or(-1);
- std::optional dummy_attn_bias = std::nullopt;
- return mha_fwd_ck(
- q,
- k,
- v,
- out_,
- p_dropout,
- softmax_scale,
- is_causal,
- non_null_window_left,
- non_null_window_right,
- return_softmax,
- gen_,
- dummy_attn_bias); // Not used in flash attention
- }
-#endif
- return mha_fwd_aot(
- q,
- k,
- v,
- out_,
- alibi_slopes_,
- p_dropout,
- softmax_scale,
- is_causal,
- window_size_left,
- window_size_right,
- return_softmax,
- gen_);
-}
+ std::optional<at::Generator> gen_);
inline std::tuple<
at::Tensor,
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index ef5c2fd4e97de..daceebd8bc889 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1037,6 +1037,22 @@ if(USE_ROCM)
list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
endif(CMAKE_BUILD_TYPE MATCHES Debug)
+ # Read the PYTORCH_LAYERNORM_FAST_RECIPROCAL environment variable (default: ON).
+ if(DEFINED ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
+ set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE $ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
+ else()
+ set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE ON)
+ endif()
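+ # Example: building with PYTORCH_LAYERNORM_FAST_RECIPROCAL=0 in the environment disables the
+ # intrinsic and keeps the plain 1.f/x emulation in layer_norm_kernel.cu.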
+
+ set(PYTORCH_LAYERNORM_FAST_RECIPROCAL
+ ${PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE}
+ CACHE BOOL "Enable fast reciprocals within layer normalization." FORCE
+ )
+
+ if(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+ add_definitions(-DPYTORCH_LAYERNORM_FAST_RECIPROCAL)
+ endif()
+
# needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)
diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake
index 5d91587746540..f09f77bedb80f 100644
--- a/cmake/External/aotriton.cmake
+++ b/cmake/External/aotriton.cmake
@@ -45,13 +45,89 @@ if(NOT __AOTRITON_INCLUDED)
)
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
set(__AOTRITON_Z "gz")
+ # Set the default __AOTRITON_LIB path
+ if(NOT WIN32)
+ set(__AOTRITON_LIB "lib/libaotriton_v2.so")
+ else()
+ set(__AOTRITON_LIB "lib/aotriton_v2.lib")
+ endif()
+
+ function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
+ # Windows-specific dependencies - build these first
+ if(NOT noimage)
+ message(FATAL_ERROR "noimage must be ON for Windows builds")
+ endif()
+ # Build dlfcn-win32
+ set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
+ set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")
+
+ ExternalProject_Add(${dlfcn-win32_external}
+ GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
+ GIT_TAG v1.4.2
+ PREFIX ${__DLFCN_WIN32_PREFIX}
+ INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
+ CMAKE_ARGS
+ -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
+ -DCMAKE_BUILD_TYPE=Release
+ -DCMAKE_C_COMPILER=cl
+ -DCMAKE_CXX_COMPILER=cl
+ -DBUILD_SHARED_LIBS=ON
+ -DBUILD_TESTS=OFF
+ BUILD_BYPRODUCTS
+ "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
+ "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
+ )
+ ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different
+ "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
+ "${__AOTRITON_INSTALL_DIR}/lib/"
+ DEPENDEES install
+ )
+ set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)
+
+ # Build xz/liblzma
+ set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
+ set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")
+
+ ExternalProject_Add(${xz_external}
+ GIT_REPOSITORY https://github.com/tukaani-project/xz.git
+ GIT_TAG v5.8.1
+ PREFIX ${__XZ_PREFIX}
+ INSTALL_DIR ${__XZ_INSTALL_DIR}
+ CMAKE_ARGS
+ -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
+ -DCMAKE_BUILD_TYPE=Release
+ -DBUILD_SHARED_LIBS=ON
+ -DENABLE_NLS=OFF
+ -DXZ_TOOL_LZMAINFO=OFF
+ -DXZ_TOOL_XZ=OFF
+ -DXZ_TOOL_XZDEC=OFF
+ -DXZ_TOOL_LZMADEC=OFF
+ BUILD_BYPRODUCTS
+ "${__XZ_INSTALL_DIR}/lib/lzma.lib"
+ "${__XZ_INSTALL_DIR}/bin/liblzma.dll"
+ )
+ ExternalProject_Add_Step(${xz_external} copy_to_aotriton
+ COMMAND ${CMAKE_COMMAND} -E copy_if_different
+ "${__XZ_INSTALL_DIR}/bin/liblzma.dll"
+ "${__AOTRITON_INSTALL_DIR}/lib/"
+ DEPENDEES install
+ )
+ set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
+ endfunction()
+
function(aotriton_build_from_source noimage project)
if(noimage)
SET(RECURSIVE "OFF")
else()
SET(RECURSIVE "ON")
endif()
+ if(WIN32)
+ message(STATUS "Building AOTriton Windows dependencies")
+ aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
+ endif()
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
+
ExternalProject_Add(${project}
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_SUBMODULES_RECURSE ${RECURSIVE}
@@ -65,12 +141,18 @@ if(NOT __AOTRITON_INCLUDED)
-DAOTRITON_GPU_BUILD_TIMEOUT=0
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NOIMAGE_MODE=${noimage}
- BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
+ -DHIP_PLATFORM=amd
+ $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
+ $<$:-Dliblzma_DIR=${liblzma_DIR}>
+ BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}"
USES_TERMINAL_DOWNLOAD TRUE
USES_TERMINAL_CONFIGURE TRUE
USES_TERMINAL_BUILD TRUE
USES_TERMINAL_INSTALL TRUE
)
+ if(WIN32)
+ add_dependencies(${project} dlfcn-win32_external xz_external)
+ endif()
endfunction()
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
@@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED)
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
"${__AOTRITON_INSTALL_DIR}"
- BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
+ BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}"
)
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
@@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED)
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
+
+ # Set up directories
+ set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
+ set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
+ set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
+ set(__DOWNLOAD_NO_EXTRACT "")
+ set(__BUILD_COMMANDS "")
+
+ # On Windows, we need custom tar extraction with UTF-8 support
+ if(WIN32)
+ set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
+ set(__BUILD_COMMANDS
+ COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
+ COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
+ )
+ set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
+ endif()
+
ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
- SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
+ DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
+ ${__DOWNLOAD_NO_EXTRACT}
+ SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
+ ${__BUILD_COMMANDS}
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
- "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
+ "${__AOTRITON_INSTALL_SOURCE_DIR}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
@@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED)
endforeach()
endforeach()
endif()
- target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
+ target_link_libraries(__caffe2_aotriton INTERFACE "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}")
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
set(AOTRITON_FOUND TRUE)
endif() # __AOTRITON_INCLUDED
diff --git a/cmake/Modules/FindOpenBLAS.cmake b/cmake/Modules/FindOpenBLAS.cmake
index 9ba86ba1ee0f4..21ae9e2521eb2 100644
--- a/cmake/Modules/FindOpenBLAS.cmake
+++ b/cmake/Modules/FindOpenBLAS.cmake
@@ -29,10 +29,15 @@ SET(Open_BLAS_LIB_SEARCH_PATHS
$ENV{OpenBLAS}/lib
$ENV{OpenBLAS_HOME}
$ENV{OpenBLAS_HOME}/lib
- )
+)
+
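+# Allow overriding the library name, e.g. OpenBLAS_LIB_NAME=openblas64 for an ILP64 build.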
+SET(Open_BLAS_LIB_NAME openblas)
+IF(DEFINED ENV{OpenBLAS_LIB_NAME})
+ SET(Open_BLAS_LIB_NAME $ENV{OpenBLAS_LIB_NAME})
+ENDIF()
FIND_PATH(OpenBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Open_BLAS_INCLUDE_SEARCH_PATHS})
-FIND_LIBRARY(OpenBLAS_LIB NAMES openblas PATHS ${Open_BLAS_LIB_SEARCH_PATHS})
+FIND_LIBRARY(OpenBLAS_LIB NAMES ${Open_BLAS_LIB_NAME} PATHS ${Open_BLAS_LIB_SEARCH_PATHS})
SET(OpenBLAS_FOUND ON)
diff --git a/docs/source/library.md b/docs/source/library.md
index b31ca95d5b6a3..9d706e2e1080e 100644
--- a/docs/source/library.md
+++ b/docs/source/library.md
@@ -56,7 +56,6 @@ via PyTorch's C++ operator registration APIs).
.. autofunction:: infer_schema
.. autoclass:: torch._library.custom_ops.CustomOpDef
:members: set_kernel_enabled
-.. autofunction:: get_kernel
```
## Low-level APIs
diff --git a/related_commits b/related_commits
new file mode 100644
index 0000000000000..b96cf18c181ab
--- /dev/null
+++ b/related_commits
@@ -0,0 +1,10 @@
+ubuntu|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex
+centos|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex
+ubuntu|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision
+centos|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision
+ubuntu|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data
+centos|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data
+ubuntu|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio
+centos|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio
+ubuntu|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao
+centos|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao
diff --git a/requirements-build.txt b/requirements-build.txt
index be19d987f73db..3d21094159d79 100644
--- a/requirements-build.txt
+++ b/requirements-build.txt
@@ -1,10 +1,10 @@
# Build System requirements
-setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0
-cmake>=3.27
-ninja
-numpy
-packaging
-pyyaml
-requests
-six # dependency chain: NNPACK -> PeachPy -> six
-typing-extensions>=4.10.0
+setuptools==79.0.1
+cmake==4.0.0
+ninja==1.11.1.3
+numpy==2.1.2
+packaging==25.0
+pyyaml==6.0.3
+requests==2.32.5
+six==1.17.0 # dependency chain: NNPACK -> PeachPy -> six
+typing_extensions==4.15.0
diff --git a/requirements.txt b/requirements.txt
index fc4b53dfd49ea..824ca112602a0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,15 +5,18 @@
# Install / Development extra requirements
build[uv] # for building sdist and wheel
-expecttest>=0.3.0
-filelock
-fsspec>=0.8.5
-hypothesis
-jinja2
-lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64"
-networkx>=2.5.1
-optree>=0.13.0
-psutil
-sympy>=1.13.3
-typing-extensions>=4.13.2
+expecttest==0.3.0
+filelock==3.20.0
+fsspec==2025.9.0
+hypothesis==5.35.1
+Jinja2==3.1.6
+lintrunner==0.12.7 ; platform_machine != "s390x" and platform_machine != "riscv64"
+networkx==2.8.8
+ninja==1.11.1.3
+numpy==2.0.2 ; python_version == "3.9"
+numpy==2.1.2 ; python_version > "3.9"
+optree==0.13.0
+psutil==7.1.0
+sympy==1.13.3
+typing_extensions==4.15.0
wheel
diff --git a/setup.py b/setup.py
index 11ca48482a761..ae0097465da66 100644
--- a/setup.py
+++ b/setup.py
@@ -162,6 +162,10 @@
# USE_ROCM_CK_SDPA=1
# Enable building CK SDPA backend in ROCm platform
#
+# PYTORCH_LAYERNORM_FAST_RECIPROCAL
+# If enabled, uses builtin fast-reciprocal (1/x) intrinsics in the layer normalization
+# kernels. Default: enabled.
+#
# Environment variables we respect (these environment variables are
# conventional and are often understood/set by other software.)
#
diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py
index e930ff787a9a4..ce456596fd55e 100644
--- a/test/dynamo/test_graph_region_tracker.py
+++ b/test/dynamo/test_graph_region_tracker.py
@@ -1,6 +1,5 @@
# Owner(s): ["module: dynamo"]
import contextlib
-import os
import torch
import torch.fx
@@ -196,21 +195,6 @@ def fn(x, y, z):
)
def test_mismatched_global_state(self):
- @contextlib.contextmanager
- def _hip_allow_tf32():
- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
- # and only for MI300+
- hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
- os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
- try:
- yield
- finally:
- if hip_allow_tf32 is not None:
- os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
- else:
- del os.environ["HIPBLASLT_ALLOW_TF32"]
-
def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
@@ -251,31 +235,29 @@ def set_default_dtype_bfloat16():
def reset_default_dtype():
torch.set_default_dtype(old_dtype)
- tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
- with tf32_ctx():
- for ctx in [
- lambda: torch.set_grad_enabled(False),
- torch.autograd.grad_mode.inference_mode,
- lambda: torch.autograd.graph.disable_saved_tensors_hooks(
- "This is not supported"
- ),
- # lambda: torch.set_num_threads(2), : Unsupported
- (set_default_dtype_bfloat16, reset_default_dtype),
- (
- lambda: torch.use_deterministic_algorithms(True),
- lambda: torch.use_deterministic_algorithms(False),
- ),
- # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
- # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
- create_toggle_fns("allow_bf16_reduced_precision_reduction"),
- create_toggle_fns("allow_fp16_reduced_precision_reduction"),
- create_toggle_fns("allow_tf32"),
- ]:
- self.assertExpectedInline(
- self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
- """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
+ for ctx in [
+ lambda: torch.set_grad_enabled(False),
+ torch.autograd.grad_mode.inference_mode,
+ lambda: torch.autograd.graph.disable_saved_tensors_hooks(
+ "This is not supported"
+ ),
+ # lambda: torch.set_num_threads(2), : Unsupported
+ (set_default_dtype_bfloat16, reset_default_dtype),
+ (
+ lambda: torch.use_deterministic_algorithms(True),
+ lambda: torch.use_deterministic_algorithms(False),
+ ),
+ # (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
+ # lambda: torch.use_deterministic_algorithms(False)), : Unsupported
+ create_toggle_fns("allow_bf16_reduced_precision_reduction"),
+ create_toggle_fns("allow_fp16_reduced_precision_reduction"),
+ create_toggle_fns("allow_tf32"),
+ ]:
+ self.assertExpectedInline(
+ self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
+ """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
- )
+ )
def test_mutation_tracking_simple(self):
def fn(x, y, z):
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 1a9d8e8155e43..0a3891e2dc146 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -8421,43 +8421,24 @@ def write_state(state):
def fn(x):
return x + 1
- import contextlib
-
- @contextlib.contextmanager
- def _hip_allow_tf32():
- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
- # and only for MI300+
- hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
- os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
- try:
- yield
- finally:
- if hip_allow_tf32 is not None:
- os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
- else:
- del os.environ["HIPBLASLT_ALLOW_TF32"]
-
- tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
- with tf32_ctx():
- initial_state = read_state()
- y = torch.randn(10)
- try:
- for round in range(3):
- for i in range(len(initial_state)):
- new_state = [False] * len(initial_state)
- new_state[i] = True
- write_state(new_state)
- assert read_state() == new_state
- last_state.clear()
- fn(y)
- assert last_state == new_state
- if round == 0:
- assert cnt == i + 1
- else:
- assert cnt == len(initial_state)
- finally:
- write_state(initial_state)
+ initial_state = read_state()
+ y = torch.randn(10)
+ try:
+ for round in range(3):
+ for i in range(len(initial_state)):
+ new_state = [False] * len(initial_state)
+ new_state[i] = True
+ write_state(new_state)
+ assert read_state() == new_state
+ last_state.clear()
+ fn(y)
+ assert last_state == new_state
+ if round == 0:
+ assert cnt == i + 1
+ else:
+ assert cnt == len(initial_state)
+ finally:
+ write_state(initial_state)
def test_grad_state_mutated(self):
prior = torch.is_grad_enabled()
diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py
index 5a61ea851eae0..cc94c4c95e01a 100644
--- a/test/inductor/test_async_compile.py
+++ b/test/inductor/test_async_compile.py
@@ -74,7 +74,14 @@ def f(a, b):
return (a @ b).to(torch.float32).sum(dim=1)
# Fake name to make sure the lookup table is name agnostic
- func_def = """
+ # When codegen/triton.py is changed, func_def must be updated
+ loop_header = (
+ "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
+ if torch.version.hip
+ else "for r0_offset in range(0, r0_numel, R0_BLOCK):"
+ )
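+ # On ROCm, Inductor emits tl.range(..., num_stages=2) for this reduction loop, so the
+ # expected kernel text differs from the CUDA codegen.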
+
+ func_def = f"""
def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
xnumel = 1024
r0_numel = 11776
@@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons
rbase = r0_base
x0 = xindex
_tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
- for r0_offset in range(0, r0_numel, R0_BLOCK):
+ {loop_header}
r0_index = r0_offset + r0_base
r0_mask = r0_index < r0_numel
roffset = r0_offset
diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py
index 90399546d26ea..6523cddcec6db 100644
--- a/test/inductor/test_combo_kernels.py
+++ b/test/inductor/test_combo_kernels.py
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
- @requires_cuda_and_triton
- def test_persistent_reduction_no_x_dim(self):
- def fn(x, y):
- return x.sum(1), y.sum(1)
-
- inps = (
- torch.rand(16, 256, device="cuda"),
- torch.rand(32, 256, device="cuda"),
- )
- torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
- torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
- out_eager = fn(*inps)
- out_compiled = torch.compile(fn)(*inps)
-
- self.assertEqual(out_eager, out_compiled)
- self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
@instantiate_parametrized_tests
class ComboKernelDynamicShapesTests(TestCase):
diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py
index 120d8d36b439d..849aefff8a965 100644
--- a/test/inductor/test_flex_decoding.py
+++ b/test/inductor/test_flex_decoding.py
@@ -43,9 +43,6 @@
Tolerances = namedtuple("Tolerances", ["atol", "rtol"])
-# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
-# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
-# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:
diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py
index 9ef3a18e24234..c67bde87a369b 100644
--- a/test/inductor/test_padding.py
+++ b/test/inductor/test_padding.py
@@ -109,9 +109,6 @@ def setUpClass(cls):
if HAS_GPU:
cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision()
cls.prior_default_device = torch.get_default_device()
- # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul.
- # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the
- # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest.
if torch.version.hip:
torch.set_float32_matmul_precision("highest")
else:
diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py
index 41db6b18daba7..6bde7a8c540a4 100644
--- a/test/inductor/test_torchinductor_strided_blocks.py
+++ b/test/inductor/test_torchinductor_strided_blocks.py
@@ -816,6 +816,7 @@ def test_2d_reduction_odd_shapes(
# Check the code for multiple Rn_BLOCK's
self._assert_reduction_ndims(code, 2)
+
@parametrize(
"size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
[
diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py
index c0419664d0098..40dca90b16488 100644
--- a/test/nn/test_multihead_attention.py
+++ b/test/nn/test_multihead_attention.py
@@ -17,7 +17,6 @@
instantiate_parametrized_tests,
parametrize as parametrize_test,
run_tests,
- skipIfRocm,
TEST_NUMPY,
TEST_WITH_CROSSREF,
)
@@ -746,7 +745,6 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self):
class TestMultiheadAttentionNNDeviceType(NNTestCase):
- @skipIfRocm(msg="To investigate: yields NaN")
def test_multihead_self_attn_two_masks_fast_path(self, device):
"""
Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path
diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py
index 1c31d5445f915..569d1bac85958 100644
--- a/test/test_binary_ufuncs.py
+++ b/test/test_binary_ufuncs.py
@@ -1480,8 +1480,8 @@ def to_np(value):
self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent)
elif torch.can_cast(torch.result_type(base, exponent), base.dtype):
actual2 = actual.pow_(exponent)
- self.assertEqual(actual, expected)
- self.assertEqual(actual2, expected)
+ self.assertEqual(actual, expected.to(actual))
+ self.assertEqual(actual2, expected.to(actual))
else:
self.assertRaisesRegex(
RuntimeError,
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 293bb2b7e701b..d293601fad138 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -504,6 +504,9 @@ def test_out_of_memory_retry(self):
IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support"
)
def test_set_per_process_memory_fraction(self):
+ if torch.version.hip and ('gfx1101' in torch.cuda.get_device_properties(0).gcnArchName):
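+ # Release cached blocks and reset peak stats so the fraction-based limit below is
+ # measured from a clean allocator state on this device.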
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
orig = torch.cuda.get_per_process_memory_fraction(0)
torch.cuda.reset_peak_memory_stats(0)
try:
@@ -759,53 +762,7 @@ def check_workspace_size(inp):
torch._C._cuda_clearCublasWorkspaces()
- @contextlib.contextmanager
- def _hip_allow_tf32(self):
- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
- # and only for MI300+
- hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
- os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
- try:
- yield
- finally:
- if hip_allow_tf32 is not None:
- os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
- else:
- del os.environ["HIPBLASLT_ALLOW_TF32"]
-
- @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing")
- def test_hipblaslt_allow_tf32(self):
- tf32_ctx = self._hip_allow_tf32
- with tf32_ctx():
- os.environ["HIPBLASLT_ALLOW_TF32"] = "0"
- # Save original value of allow_tf32
- orig = torch.backends.cuda.matmul.allow_tf32
- # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp
- # then matmul.allow_tf32 will return False after this point even if
- # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed.
- os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
- # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times.
- torch.backends.cuda.matmul.allow_tf32 = not orig
- test1 = torch.backends.cuda.matmul.allow_tf32
- torch.backends.cuda.matmul.allow_tf32 = orig
- test2 = torch.backends.cuda.matmul.allow_tf32
- self.assertNotEqual(test1, test2)
- # Restore original value of allow_tf32
- torch.backends.cuda.matmul.allow_tf32 = orig
-
def test_cublas_allow_tf32_get_set(self):
- """
- We only turn on TF32 for MI300 with a special env var. This is because TF32
- is only available in MI300+ and is in experimental mode (hipblaslt support
- is current WIP)
- """
- tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
- with tf32_ctx():
- self._test_cublas_allow_tf32_get_set_inner()
-
- def _test_cublas_allow_tf32_get_set_inner(self):
skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
)
@@ -820,12 +777,6 @@ def _test_cublas_allow_tf32_get_set_inner(self):
torch.backends.cuda.matmul.allow_tf32 = orig
def test_float32_matmul_precision_get_set(self):
- tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
- with tf32_ctx():
- self._test_float32_matmul_precision_get_set_inner()
-
- def _test_float32_matmul_precision_get_set_inner(self):
orig = torch.get_float32_matmul_precision()
skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 491648494f6f0..5a494f5487423 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -11,7 +11,6 @@
import tempfile
import typing
import unittest
-from functools import partial
from pathlib import Path
from typing import * # noqa: F403
@@ -4157,148 +4156,6 @@ def test_any_output_is_alias_to_input_or_output(self):
)
)
- def test_library_get_kernel(self):
- """Test registering a custom kernel, using it, then deregistering and verifying error."""
-
- # Register a dummy kernel for arange to the CPU key that returns a tensor of ones
- def dummy_arange_cpu(
- dispatch_keys,
- start,
- end,
- dtype=None,
- layout=torch.strided,
- device=None,
- pin_memory=False,
- ):
- size = max(0, int(end - start))
- return torch.ones(size, dtype=dtype, device=device)
-
- with torch.library._scoped_library("aten", "IMPL") as lib:
- lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True)
-
- kernel = torch.library.get_kernel("aten::arange.start", "CPU")
- dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
- result = kernel.call_boxed(dispatch_keys, 0, 5)
-
- self.assertEqual(result, torch.ones(5))
-
- # The kernel should now be invalidated after exiting the scoped_library context
- with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
- kernel.call_boxed(dispatch_keys, 0, 5)
-
- def test_library_get_kernel_with_conditional_dispatch(self):
- """Test registering a custom kernel with conditional dispatch logic."""
-
- def conditional_arange_cpu1(
- original_kernel,
- dispatch_keys,
- start,
- end,
- dtype=None,
- layout=torch.strided,
- device=None,
- pin_memory=False,
- ):
- # If end is even, use the original kernel, otherwise return ones tensor
- if end % 2 == 0:
- op_handle = torch.ops.aten.arange.start._handle
- return original_kernel.call_boxed(
- dispatch_keys,
- start,
- end,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- )
- else:
- size = max(0, int(end - start))
- return torch.ones(size, dtype=dtype, device=device)
-
- def conditional_arange_cpu2(
- original_kernel,
- dispatch_keys,
- start,
- end,
- dtype=None,
- layout=torch.strided,
- device=None,
- pin_memory=False,
- ):
- # If start is even, use the original kernel, otherwise return twos tensor
- if start % 2 == 0:
- op_handle = torch.ops.aten.arange.start._handle
- return original_kernel.call_boxed(
- dispatch_keys,
- start,
- end,
- dtype=dtype,
- layout=layout,
- device=device,
- pin_memory=pin_memory,
- )
- else:
- size = max(0, int(end - start))
- return torch.empty(size, dtype=dtype, device=device).fill_(2)
-
- original_kernel = torch.library.get_kernel("aten::arange.start", "CPU")
- expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6)
- expected_result3, expected_result4, expected_result5 = (
- torch.ones(5),
- torch.arange(0, 6),
- torch.ones(5).fill_(2),
- )
-
- with torch.library._scoped_library("aten", "IMPL") as lib2:
- with torch.library._scoped_library("aten", "IMPL") as lib1:
- lib1.impl(
- "arange.start",
- partial(conditional_arange_cpu1, original_kernel),
- "CPU",
- with_keyset=True,
- )
-
- self.assertEqual(torch.arange(0, 5), expected_result1)
- self.assertEqual(torch.arange(0, 6), expected_result2)
- new_original_kernel = torch.library.get_kernel(
- "aten::arange.start", "CPU"
- )
- lib2.impl(
- "arange.start",
- partial(conditional_arange_cpu2, new_original_kernel),
- "CPU",
- allow_override=True,
- with_keyset=True,
- )
-
- self.assertEqual(torch.arange(0, 5), expected_result3)
- self.assertEqual(torch.arange(0, 6), expected_result4)
- self.assertEqual(torch.arange(1, 6), expected_result5)
-
- # The kernel should now be invalidated after destroying lib1
- with self.assertRaisesRegex(RuntimeError, "has been invalidated"):
- torch.arange(0, 5)
-
- # Should still work after destroying lib1
- self.assertEqual(torch.arange(1, 6), expected_result5)
-
- def test_library_get_kernel_invalid(self):
- """Test that get_kernel raises an error when no kernel is available."""
- with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib:
- lib.define("cpu_only_op(Tensor x) -> Tensor")
- lib.impl("cpu_only_op", lambda x: x * 2, "CPU")
-
- cpu_kernel = torch.library.get_kernel(
- "test_invalid_kernel::cpu_only_op", "CPU"
- )
- self.assertIsNotNone(cpu_kernel)
-
- # CUDA should fail at the isValid() check since no CUDA kernel exists
- with self.assertRaisesRegex(
- RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op"
- ):
- torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA")
-
class MiniOpTestOther(CustomOpTestCaseBase):
test_ns = "mini_op_test"
diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py
index c44d5e5d41454..17e699e04e589 100644
--- a/test/test_flop_counter.py
+++ b/test/test_flop_counter.py
@@ -15,7 +15,6 @@
)
from torch.testing._internal.common_utils import (
run_tests,
- skipIfRocm,
TEST_WITH_TORCHDYNAMO,
TestCase,
)
@@ -463,7 +462,6 @@ def get_flops(
self.assertExpectedInline(str(flops_fw_bw_math), """805306368""")
self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""")
- @skipIfRocm # Nested tensor
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION
@@ -683,7 +681,6 @@ def split_tensor(x):
),
)
- @skipIfRocm # Nested tensor
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 0f6c8f207421b..31d4e0d1d92d5 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -109,22 +109,6 @@ def get_tunableop_untuned_filename():
return untuned_filename
class TestLinalg(TestCase):
- @contextlib.contextmanager
- def _hip_allow_tf32(self):
- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
- # and only for MI300+. Environment variable will be removed in the future.
- import os
- hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
- os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
-
- try:
- yield
- finally:
- if hip_allow_tf32 is not None:
- os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
- else:
- del os.environ["HIPBLASLT_ALLOW_TF32"]
-
def setUp(self):
super().setUp()
torch.backends.cuda.matmul.allow_tf32 = False
@@ -5542,13 +5526,8 @@ def test_scaled_gemm_tunableop(self, device, dtype):
@runOnRocmArch(MI300_ARCH)
@dtypes(torch.float)
def test_tf32_tunableop(self, device, dtype):
- # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
- # and only for MI300+. Eventually this flag will go away.
- tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
try:
- with self._tunableop_ctx(), tf32_ctx():
+ with self._tunableop_ctx():
torch.backends.cuda.matmul.allow_tf32 = True
torch.cuda.tunable.set_rotating_buffer_size(0)
@@ -5611,13 +5590,8 @@ def test_tf32_offline_tunableop(self, device, dtype):
# This test is the offline version of test_tf32_tunableop
import os
- # Test TunableOp with TF32. Supported by hipblasLT on MI300+.
- # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new
- # and only for MI300+. Eventually this flag will go away.
- tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
-
try:
- with self._tunableop_ctx(), tf32_ctx():
+ with self._tunableop_ctx():
torch.backends.cuda.matmul.allow_tf32 = True
ordinal = torch.cuda.current_device()
torch.cuda.tunable.set_rotating_buffer_size(0)
diff --git a/test/test_nn.py b/test/test_nn.py
index c17f7cb668b6f..d5c245c5887d2 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -39,7 +39,7 @@
parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \
skipIfTorchDynamo, gcIfJetson, set_default_dtype
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
- PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version
+ _get_torch_rocm_version
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \
ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input
@@ -3166,7 +3166,6 @@ def perm_fn(x):
[2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
- @skipIfRocm(msg='Large numerical errors')
def test_transformerdecoder(self):
def get_a_test_layer(use_cuda, activation, batch_first=False):
d_model = 4
@@ -12998,8 +12997,6 @@ def test_skip_init(self, device):
@dtypes(torch.float)
@dtypesIfCUDA(torch.double, torch.float, torch.half)
def test_transformerencoderlayer(self, device, dtype):
- if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
- self.skipTest("Skip on ROCM due to Flash Attention tolerances")
# this is a deterministic test for TransformerEncoderLayer
d_model = 4
nhead = 2
@@ -13221,8 +13218,6 @@ def test_transformerencoderlayer_fast_path(self, device, dtype):
@dtypes(torch.float)
@dtypesIfCUDA(torch.half, torch.float)
def test_transformerencoderlayer_gelu(self, device, dtype):
- if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half:
- self.skipTest("Skip on ROCM due to Flash Attention tolerances")
# this is a deterministic test for TransformerEncoderLayer with gelu activation
d_model = 4
nhead = 2
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 5b240e1f046c9..b2a3959a50429 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -51,7 +51,6 @@
PLATFORM_SUPPORTS_CUDNN_ATTENTION,
tf32_on_and_off,
tf32_enabled,
- ROCM_VERSION,
)
if TEST_FAIRSEQ:
@@ -340,14 +339,11 @@ def test_train_with_pad_and_catch_error(self, device):
l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item()
self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL")
- @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+ @tf32_on_and_off(0.001)
@parametrize("attn_mask_dim", [2, 3, None])
@parametrize("key_padding_mask_dim", [2, None])
@parametrize("mask_dtype", [torch.bool, torch.float32])
def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype):
- if TEST_WITH_ROCM:
- if attn_mask_dim is not None and mask_dtype == torch.bool:
- self.skipTest("boolean mask is not fully supported on ROCm yet.")
# MHA converts all
with torch.no_grad():
B = 2
@@ -430,8 +426,7 @@ def hook(module, inputs, output):
# remove hook
handle.remove()
- @skipIfRocm
- @tf32_on_and_off(0.001)
+ @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001)
@parametrize("use_torchscript", [False])
@parametrize("enable_nested_tensor", [True, False])
@parametrize("use_autocast", [True, False])
@@ -524,7 +519,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste
slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0)
self.assertEqual(fastpath_output_expanded, slowpath_output)
- @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+ @tf32_on_and_off(0.001)
@parametrize("with_no_grad", [True, False])
@parametrize("training", [True, False])
@parametrize("enable_nested_tensor", [False])
@@ -1110,7 +1105,7 @@ def forward(
return_all_hiddens=False,
)[0]
- @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0)))
+ @tf32_on_and_off(0.003)
@parametrize("input_dim,attn_mask_dim,is_causal",
[(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True),
(4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)],
@@ -1421,7 +1416,6 @@ def ones_tensor(*shape):
_ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True)
torch.cuda.synchronize()
- @skipIfRocm # Missing EFFICIENT_ATTENTION
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
)
@@ -1714,7 +1708,7 @@ def test_unaligned_tensors(self, device):
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
q, k, v = make_tensor(), make_tensor(), make_tensor()
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
- ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+ ctxmgr = self.assertRaises(RuntimeError)
with ctxmgr:
torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
@@ -2612,7 +2606,6 @@ def convert_flash_attn_S_to_softmax(
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
return S_converted[:, :, :seqlen_q, :seqlen_k]
- @skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_cudnn_attention_different_dk_dv(self, device):
dtype = torch.bfloat16
@@ -2636,7 +2629,6 @@ def test_cudnn_attention_different_dk_dv(self, device):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
- @skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_cudnn_attention_gqa(self, device):
batch = 4
@@ -2660,7 +2652,6 @@ def test_cudnn_attention_gqa(self, device):
self.assertEqual(output_math, output_cudnn)
- @skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_cudnn_attention_d256_heuristic(self, device):
dtype = torch.bfloat16
@@ -2691,7 +2682,6 @@ def test():
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
test()
- @skipIfRocm(msg="No cuDNN on ROCm")
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
def test_fused_attention_different_dk_dv(self, device):
dtype = torch.bfloat16
@@ -2715,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device):
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
- @skipIfRocm # No cuDNN Attention
+ @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
@unittest.skipIf(True, "broken as of cuDNN 9.10")
def test_cudnn_attention_fail_d128(self, device):
# Test that cuDNN attention dispatching correctly bails out on d > 128
@@ -2737,7 +2727,6 @@ def test_cudnn_attention_fail_d128(self, device):
with self.assertRaisesRegex(RuntimeError, "No available kernel."):
torch.nn.functional.scaled_dot_product_attention(q, k, v)
- @skipIfRocm(msg="No cuDNN on ROCm")
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_trivial_output_transpose(self, device):
# see also: https://github.com/pytorch/pytorch/issues/134001
@@ -2753,7 +2742,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device):
o.backward(o)
torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3)
- @skipIfRocm # No cuDNN Attention
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_nonmodulo64seqlen(self, device):
# see also: https://github.com/pytorch/pytorch/issues/137347
@@ -2793,7 +2781,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device):
torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3)
- @skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_preserves_query_layout(self, device):
@@ -2823,7 +2810,6 @@ def test_attention(backend: SDPBackend, permute_order: list[list[int]]):
for permute_order in permute_orders:
test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3])
- @skipIfRocm
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system")
def test_cudnn_attention_compiles(self):
q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True)
@@ -3242,7 +3228,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
- @skipIfRocm
@onlyCUDA
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
diff --git a/third_party/fbgemm b/third_party/fbgemm
index 4b39c551efe15..3cefe0564a8c3 160000
--- a/third_party/fbgemm
+++ b/third_party/fbgemm
@@ -1 +1 @@
-Subproject commit 4b39c551efe15e6bbade20565b0ceb2d8ce3352d
+Subproject commit 3cefe0564a8c3de514a152d40a2b4770f2ee5be0
diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt
index 706881a8f10f6..c4a250db04836 100644
--- a/tools/linter/dictionary.txt
+++ b/tools/linter/dictionary.txt
@@ -12,6 +12,7 @@ BU
contiguities
contiguity
coo
+DEPENDEES
deser
din
dout
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 80437aa1d833e..5fe3f7e178b73 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1695,11 +1695,6 @@ class _DispatchModule:
_after_ADInplaceOrView_keyset: DispatchKeySet
_after_autograd_keyset: DispatchKeySet
-class _SafeKernelFunction:
- def call_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ...
- @property
- def op_handle(self) -> _DispatchOperatorHandle: ...
-
def _dispatch_library(
kind: str,
name: str,
@@ -1737,10 +1732,6 @@ def _dispatch_has_computed_kernel_for_dispatch_key(
name: str,
dispatch: _dispatchkey,
) -> _bool: ...
-def _dispatch_get_computed_kernel_for_dispatch_key(
- name: str,
- dispatch: _dispatchkey,
-) -> _SafeKernelFunction: ...
def _dispatch_find_dangling_impls() -> list[str]: ...
def _dispatch_get_all_op_names() -> list[str]: ...
def _dispatch_tls_set_dispatch_key_excluded(
diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py
index 417fac7b4f634..2189e44f9e246 100644
--- a/torch/_inductor/choices.py
+++ b/torch/_inductor/choices.py
@@ -232,6 +232,18 @@ def should_use_persistent_reduction(
features.reduction_numel, threshold
) # type: ignore[arg-types]
+ @staticmethod
+ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
+ """
+ Heuristic to decide whether we should drop the X dimension from a persistent reduction kernel,
+ so the [XBLOCK, RBLOCK] block becomes an [RBLOCK] block and XBLOCK is forced to always be 1.
+ Strangely, this is faster than a [1, RBLOCK] block in some cases.
+
+ ROCm branch change: remove want_no_x_dim for persistent reductions.
+ Inductor benchmarks show no perf advantage, and removing it simplifies the autotune flow.
+ """
+ return False
+
@staticmethod
def reduction_split_factor(
device: torch.device,
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 175ea55ec3af2..17a336cc3cf2e 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -1101,11 +1101,17 @@ def relu(x):
@staticmethod
def minimum(a, b):
- return f"triton_helpers.minimum({a}, {b})"
+ if torch.version.hip:
+ return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+ else:
+ return f"triton_helpers.minimum({a}, {b})"
@staticmethod
def maximum(a, b):
- return f"triton_helpers.maximum({a}, {b})"
+ if torch.version.hip:
+ return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+ else:
+ return f"triton_helpers.maximum({a}, {b})"
@staticmethod
def where(a, b, c):
@@ -1291,7 +1297,10 @@ def load_seed(name, offset):
@staticmethod
@maybe_upcast_float32()
def rsqrt(x):
- return f"libdevice.rsqrt({x})"
+ if torch.version.hip:
+ return f"tl.rsqrt({x})"
+ else:
+ return f"libdevice.rsqrt({x})"
@staticmethod
@maybe_upcast_float32()
@@ -1306,7 +1315,7 @@ def tan(x):
@staticmethod
@maybe_upcast_float32()
def tanh(x):
- return f"libdevice.tanh({x})"
+ return f"libdevice.fast_tanhf({x})"
@staticmethod
@maybe_upcast_float32()
@@ -2030,12 +2039,11 @@ def should_use_persistent_reduction(self) -> bool:
)
def want_no_x_dim(self):
- return (
- self.persistent_reduction
- and len(self.numels) == self.num_reduction_dims + 1
- and self.fixed_config
- and self.fixed_config["XBLOCK"] == 1
- )
+ """
+ ROCm branch change: remove want_no_x_dim for persistent reductions.
+ Inductor benchmarks show no perf advantage, and removing it simplifies the autotune flow.
+ """
+ return False
@property
def assert_function(self) -> str:
@@ -3789,8 +3797,9 @@ def codegen_body(self):
loop_end = (
"rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
)
+ num_stages = ", num_stages = 2" if torch.version.hip else ""
self.body.writeline(
- f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+ f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
)
with self.body.indent(offset=level + 1):
self.iteration_ranges_codegen_header(tree, self.body)
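
For clarity, the codegen change above only alters the emitted loop header: on ROCm the generated reduction loop gains a software-pipelining hint. A minimal sketch of the string being produced, with illustrative placeholder values for the prefix and loop bounds (not taken from a real kernel):

    import torch

    prefix, loop_start, loop_end = "r0_", "0", "r0_numel"
    # Mirrors the codegen_body change: append num_stages only on HIP builds.
    num_stages = ", num_stages = 2" if torch.version.hip else ""
    print(f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):")
    # ROCm:      for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):
    # otherwise: for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
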
diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py
index dc2392119cc51..94a905e4211ce 100644
--- a/torch/_inductor/codegen/triton_combo_kernel.py
+++ b/torch/_inductor/codegen/triton_combo_kernel.py
@@ -614,7 +614,7 @@ def jit_line(
if heuristics == "foreach":
heuristics_line = f"""
@triton_heuristics.foreach(
- num_warps={self.num_warps},
+ filename=__file__,
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
)
diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py
index f6921a057ba0f..cbe2bb44f6c86 100644
--- a/torch/_inductor/config.py
+++ b/torch/_inductor/config.py
@@ -1391,7 +1391,7 @@ class triton:
# So far we see a fixed 8 spilled registers for kernels using sin/cos.
# Raise the threshold to 16 to be safe.
# We should revisit this once we understand more of the source of register spills.
- spill_threshold: int = 16
+ spill_threshold: int = 32 if torch.version.hip else 16
# Generate code containing the newer tl.make_block_ptr() API for loads/store
use_block_ptr = False
@@ -1442,6 +1442,15 @@ class triton:
os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32")
)
+ # Map storing the number of runs for which each kernel's input tensors were dumped
+ # Keyed by a hash of the Triton source code to avoid bloating the folder
+ kernel_dump_occurrence_map: dict[str, int] = {}
+
+ # Maximum number of runs for which dumped kernel input tensors are kept
+ # When the maximum is reached the earliest dumps get overwritten
+ # This ensures the last N runs are saved, where N is this value
+ max_kernel_dump_occurrences = 3
+
class aot_inductor:
"""
diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index ad7a0d56fc4b1..26b3bcf5cc5cf 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -3,6 +3,7 @@
import itertools
import logging
from typing import Callable, Optional, TYPE_CHECKING
+from functools import lru_cache
from .hints import TRITON_MAX_BLOCK
from .runtime_utils import red_text, triton_config_to_hashable
@@ -60,10 +61,16 @@ def get_config_max(self, prefix: str) -> int:
size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
return min(max_block, size_hint) if size_hint is not None else max_block
+ @lru_cache(maxsize=1)
def get_warpsmax(self):
- # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
- # number of warps.
- return 1024 // 32
+ # CUDA/ROCm has a maximum of 1024 threads per block
+ from torch.cuda import current_device, get_device_properties, is_available
+
+ warp_size = (
+ get_device_properties(current_device()).warp_size if is_available() else 32
+ )
+
+ return 1024 // warp_size
def cache_benchmark_result(self, config, timing):
self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
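
The warp-size-aware cap above mainly matters on ROCm, where wavefronts are commonly 64 lanes wide. The same arithmetic can be checked in isolation (a sketch; the fallback of 32 matches the tuner's behavior when no GPU is present):

    import torch

    # 1024 threads per block divided by the hardware warp/wavefront size.
    if torch.cuda.is_available():
        warp_size = torch.cuda.get_device_properties(torch.cuda.current_device()).warp_size
    else:
        warp_size = 32  # fallback when no GPU is present
    print("max warps per block:", 1024 // warp_size)  # 32 on CUDA, typically 16 on ROCm
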
diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py
index 15b86b1b3d1ae..a1a0a792c9b84 100644
--- a/torch/_inductor/runtime/hints.py
+++ b/torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
# NOTE: if these fail asserts submit a PR to increase them
TRITON_MAX_BLOCK = {
- "X": 4096,
+ "X": 8192,
"Y": 1024,
"Z": 1024,
"R0_": 4096 * 16, # * 16 is multi-kernel only
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 547fad5222465..ffcfb98a6bf32 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -34,6 +34,7 @@
from torch._environment import is_fbcode
from torch._prims_common import compute_required_storage_length
from torch.utils._ordered_set import OrderedSet
+from torch._inductor.config import triton as inductor_triton_config
from ..triton_bundler import TritonBundler
from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict
@@ -223,6 +224,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
f.write(f"{kernel_name} | {args_str} | {grid!r}\n")
+def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name):
+ tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)]
+
+ run_index = 0
+
+ # Some kernels don't have path and hash stored
+ # Using only the name to differentiate between those
+ if not kernel_path:
+ kernel_hash = kernel_name
+
+ # Saving only the last N runs of the kernels to avoid bloating the folder
+ if kernel_hash in inductor_triton_config.kernel_dump_occurrence_map:
+ run_index = inductor_triton_config.kernel_dump_occurrence_map[kernel_hash] + 1
+
+ if run_index >= inductor_triton_config.max_kernel_dump_occurrences:
+ run_index = 0
+
+ inductor_triton_config.kernel_dump_occurrence_map[kernel_hash] = run_index
+
+ # Default path for kernels with no hash
+ if not kernel_path:
+ directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs"
+ else:
+ directory_path = os.path.dirname(kernel_path)
+ directory_path = f"{directory_path}/{kernel_name}_run_{run_index}"
+ os.makedirs(directory_path, exist_ok=True)
+
+ tensor_index = 0
+ for tensor in tensor_list:
+ torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt")
+ tensor_index += 1
+
+
def check_autotune_cache(
configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
@@ -367,6 +401,10 @@ def __init__(
self.dump_launch_params = (
os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1"
)
+ self.dump_launch_tensors = (
+ os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1"
+ )
+ self.kernels_to_dump = [k for k in os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",") if k]
self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"
@@ -838,7 +876,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
# for some (complicated) custom Triton kernels, a register-spilling
# config may yield the best latency.
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
- "spill_threshold", 16
+ "spill_threshold", 32 if torch.version.hip else 16
):
log.debug(
"Skip config %s because of register spilling: %d",
@@ -1291,6 +1329,11 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
new_args, grid = self._interpret_args_grid(args, launcher.config)
_dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid)
+ if self.dump_launch_tensors:
+ # Check the kernel name if the list was provided
+ if not self.kernels_to_dump or any(kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump):
+ _dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__)
+
# it is faster than entering and exiting a context manager, even if the context
# manager is a nullcontext.
if autograd_profiler._is_profiler_enabled:
@@ -2163,6 +2206,9 @@ def triton_config(
num_stages=1,
num_elements_per_warp=256,
min_elem_per_thread=0,
+ num_warps=None,
+ matrix_instr=None,
+ waves_per_eu=None,
) -> Config:
"""
Construct a pointwise triton config with some adjustment heuristics
@@ -2219,9 +2265,11 @@ def triton_config(
):
z *= 2
- num_warps = _num_warps(
- conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
- )
+ # Calculate num_warps if it was not explicitly passed to the config
+ if num_warps is None:
+ num_warps = _num_warps(
+ conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1
+ )
# we are going to arrive at 2 warps only if bs was too small due to
# numel being too small. However to workaround some ptx bugs we still
# want at least 4 warps if there's enough elements per thread
@@ -2251,7 +2299,15 @@ def triton_config(
cfg["ZBLOCK"] = z
check_max_block(cfg)
check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel)
- return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+ config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+ if torch.version.hip:
+ if matrix_instr is not None:
+ config.kwargs["matrix_instr_nonkdim"] = matrix_instr
+ if waves_per_eu is not None:
+ config.kwargs["waves_per_eu"] = waves_per_eu
+
+ return config
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
@@ -2299,6 +2355,7 @@ def triton_config_reduction(
num_stages=1,
num_warps=None,
register_intensive=False,
+ waves_per_eu=None,
dynamic_scale_rblock=True,
) -> Config:
"""
@@ -2343,13 +2400,19 @@ def total_numel() -> int:
cfg = _get_config({"x": x, **rnumels})
check_max_block(cfg)
check_config(cfg, xnumel=size_hints["x"])
- return InductorConfig(
+ config = InductorConfig(
cfg,
num_warps=num_warps,
num_stages=num_stages,
dynamic_scale_rblock=dynamic_scale_rblock,
)
+ if torch.version.hip:
+ if waves_per_eu is not None:
+ config.kwargs["waves_per_eu"] = waves_per_eu
+
+ return config
+
def _get_config(numels: dict[str, int]) -> dict[str, int]:
"""
@@ -2360,7 +2423,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
def triton_config_tiled_reduction(
- size_hints, x, y, r, num_stages=1, register_intensive=False
+ size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
):
"""
Construct a tile reduction triton config with some adjustment
@@ -2397,7 +2460,11 @@ def total_numel() -> int:
)
check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
check_max_block(cfg)
- return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+ config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+ if torch.version.hip:
+ if waves_per_eu is not None:
+ config.kwargs["waves_per_eu"] = waves_per_eu
+ return config
def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@@ -2486,11 +2553,38 @@ def pointwise(
triton_config_with_settings(
size_hints, bs // 2, num_elements_per_warp=64
),
+ triton_config_with_settings(
+ size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
+ ),
+ triton_config_with_settings(
+ size_hints,
+ 4096, # wrt: better than the max_block for some kernel
+ ),
*hinted_configs,
]
+ # Additional pointwise configs appended for ROCm builds
+ if torch.version.hip:
+ configs.append(
+ triton_config_with_settings(
+ size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1
+ )
+ ) # 20% improvement
+ configs += [
+ triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement (workload unspecified)
+ triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel
+ triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1),
+ # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37,
+ # triton_poi_fused_index_put_new_zeros_45
+ # triton_poi_fused_index_put_new_zeros_49
+ # triton_poi_fused_index_put_new_zeros_54
+ triton_config_with_settings(size_hints, 128, num_warps=1, num_stages=1), # wri0: 56 us: triton_poi_fused_cat_mul_sigmoid_view_51
+ ]
if len(size_hints) == 2:
+ # Only avoid tuning on TileHint.SQUARE on non-ROCm builds;
+ # ROCm has observed improvements from diverging here
if (
- disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
+ disable_pointwise_autotuning(inductor_meta)
+ or (torch.version.hip is None and tile_hint == TileHint.SQUARE)
) and not (
inductor_meta.get("max_autotune")
or inductor_meta.get("max_autotune_pointwise")
@@ -2499,13 +2593,36 @@ def pointwise(
else:
configs = [
triton_config_with_settings(size_hints, 32, 32),
+ triton_config_with_settings(
+ size_hints, 64, 32
+ ), # better for some kernels
triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16
- triton_config_with_settings(size_hints, 256, 16),
+ triton_config_with_settings(size_hints, 256, 16),
triton_config_with_settings(size_hints, 16, 256),
+ triton_config_with_settings(
+ size_hints, 128, 16
+ ), # +10% for some kernels
+ triton_config_with_settings(size_hints, 128, 32), # additional 10% more
+ triton_config_with_settings(
+ size_hints, 32, 512
+ ), # +30% for some kernels
triton_config_with_settings(size_hints, bs, 1),
triton_config_with_settings(size_hints, 1, bs),
*hinted_configs,
]
+ if torch.version.hip:
+ configs += [ # add here
+ ]
+ # bypass triton_config_with_settings -> triton_config logic
+ if "x" in size_hints and "y" in size_hints:
+ configs += [
+ Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19
+ Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52
+ Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103
+ Config({"XBLOCK":64, "YBLOCK": 256}, num_warps=4), # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19
+ Config({"XBLOCK":512, "YBLOCK": 64}, num_warps=8), # wri0: 58us: triton_poi_fused_clone_53
+ ]
+
if len(size_hints) == 3:
if disable_pointwise_autotuning(inductor_meta):
configs = [triton_config_with_settings(size_hints, 16, 16, 16)]
@@ -2544,6 +2661,11 @@ def _reduction_configs(
# Convert reductions to 1D, to simplify heuristics.
rnumel = get_total_reduction_numel(size_hints)
+ # Is max autotune enabled
+ max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+ "max_autotune_pointwise"
+ )
+
register_intensive = False
MAX_R0_BLOCK = 2048
loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
@@ -2572,6 +2694,7 @@ def make_config(
num_stages=1,
register_intensive=False,
dynamic_scale_rblock=True,
+ waves_per_eu=None,
):
# For 3D case with tiling scores, create an adapted version
if "y" in size_hints:
@@ -2584,6 +2707,7 @@ def make_config(
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
+ waves_per_eu=waves_per_eu,
)
else:
# For other cases, use the original function
@@ -2594,6 +2718,7 @@ def make_config(
num_warps=num_warps,
num_stages=num_stages,
register_intensive=register_intensive,
+ waves_per_eu=waves_per_eu,
dynamic_scale_rblock=dynamic_scale_rblock,
)
@@ -2674,33 +2799,45 @@ def outer_config_opt():
)
configs.append(c)
+ result_configs = []
+
# For 3d tiling, default to more autotuning initially
- if "y" in size_hints:
- pass
- elif inductor_meta.get("max_autotune") or inductor_meta.get(
- "max_autotune_pointwise"
- ):
- pass # skip all these cases
- elif reduction_hint == ReductionHint.INNER:
- return configs + [contiguous_config]
- elif reduction_hint == ReductionHint.OUTER:
- return configs + [outer_config]
- elif reduction_hint == ReductionHint.OUTER_TINY:
- return configs + [tiny_config]
- if disable_pointwise_autotuning(inductor_meta):
- return configs + [make_config(32, 128)]
-
- return configs + [
- contiguous_config,
- outer_config,
- tiny_config,
- make_config(64, 64),
- make_config(8, 512),
- # halve the XBLOCK/Rn_BLOCK compared to outer_config
- # TODO: this may only be beneficial when each iteration of the reduction
- # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
- make_config(64, 4, num_warps=8),
- ]
+ if not (max_autotune_enabled or "y" in size_hints):
+ if reduction_hint == ReductionHint.INNER:
+ result_configs = configs + [contiguous_config]
+ elif reduction_hint == ReductionHint.OUTER:
+ result_configs = configs + [outer_config]
+ elif reduction_hint == ReductionHint.OUTER_TINY:
+ result_configs = configs + [tiny_config]
+ else:
+ result_configs = configs + [make_config(32, 128)]
+ else:
+ result_configs = configs + [
+ contiguous_config,
+ outer_config,
+ tiny_config,
+ make_config(64, 64),
+ make_config(8, 512),
+ # halve the XBLOCK/Rn_BLOCK compared to outer_config
+ # TODO: this may only be beneficial when each iteration of the reduction
+ # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+ make_config(64, 4, num_warps=8),
+ ]
+
+ if torch.version.hip:
+ result_configs.extend(
+ [
+ make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+ make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+ make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8
+ make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4
+ make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153
+ make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16
+ make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29
+ ]
+ )
+
+ return result_configs
def match_target_block_product(
@@ -2758,6 +2895,7 @@ def adapt_config_for_tiling(
num_stages=1,
register_intensive=False,
persistent_reduction=False,
+ waves_per_eu=None,
) -> Config:
"""
Create an adapted configuration based on tiling scores,
@@ -2776,6 +2914,7 @@ def adapt_config_for_tiling(
block_sizes["r0_"],
num_stages=num_stages,
register_intensive=register_intensive,
+ waves_per_eu=waves_per_eu,
)
@@ -2868,13 +3007,25 @@ def _persistent_reduction_configs(
):
xnumel = size_hints["x"]
rnumel = get_total_reduction_numel(size_hints)
+ loads_and_stores = inductor_meta.get("num_load", 0) + inductor_meta.get(
+ "num_store", 0
+ )
MAX_PERSISTENT_BLOCK_NUMEL = 4096
+ max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+ inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+ )
+
+ if torch.version.hip:
+ xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256]
+ else:
+ xblock_vals = [1, 8, 32, 128]
+
if "y" not in size_hints:
configs = [
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
- for xblock in (1, 8, 32, 128)
+ for xblock in xblock_vals
if xblock == 1
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
]
@@ -2882,7 +3033,7 @@ def _persistent_reduction_configs(
configs = []
assert "tiling_scores" in inductor_meta
x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
- for target_block_size in (1, 8, 32, 64, 128):
+ for target_block_size in xblock_vals:
if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
continue
@@ -2895,22 +3046,49 @@ def _persistent_reduction_configs(
)
)
+ tiny_configs = [
+ triton_config_reduction(
+ size_hints,
+ 2 * (256 // rnumel) if rnumel <= 256 else 1,
+ rnumel,
+ )
+ ]
+
# defer to more autotuning, initially
if "y" in size_hints:
pass
# TODO(jansel): we should be able to improve these heuristics
- elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
- configs = configs[:1]
- elif reduction_hint == ReductionHint.OUTER:
- configs = configs[-1:]
- elif reduction_hint == ReductionHint.OUTER_TINY:
- configs = [
- triton_config_reduction(
- size_hints,
- 2 * (256 // rnumel) if rnumel <= 256 else 1,
- rnumel,
- )
- ]
+ elif not max_autotune_enabled: # Do not filter configs when tuning
+ if reduction_hint == ReductionHint.INNER:
+ if rnumel > 1024:
+ configs = configs[:1]
+ else:
+ x_block = 8
+ if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
+ # Five or more loads/stores means a lot of register pressure.
+ # rnumel < 256 would mean no vectorized loads if we split up the r dim,
+ # so in that case xblock still needs to stay larger
+ x_block = 1
+
+ configs = [
+ triton_config_reduction(
+ size_hints,
+ x_block,
+ rnumel,
+ register_intensive=True,
+ )
+ ]
+
+ elif reduction_hint == ReductionHint.OUTER:
+ configs = configs[-1:]
+ elif reduction_hint == ReductionHint.OUTER_TINY:
+ configs = tiny_configs
+ else:
+ # If autotuning is enabled, append the tiny configs
+ for conf in tiny_configs:
+ if conf not in configs:
+ configs.append(conf)
+
for c in configs:
# we don't need Rn_BLOCK for persistent reduction
for prefix in size_hints:
@@ -3102,13 +3280,24 @@ def user_autotune(
)
-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton foreach kernel
"""
+ configs = []
+
+ # Naive autotuning path for num_warps
+ if disable_pointwise_autotuning(inductor_meta) and not (
+ inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+ ):
+ configs.append(triton.Config({}, num_stages=1, num_warps=8))
+ else:
+ for warps in [1, 2, 4, 8]:
+ configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
return cached_autotune(
None,
- [triton.Config({}, num_stages=1, num_warps=num_warps)],
+ configs,
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
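
The launch-tensor dumping added above is opt-in via environment variables. A hedged usage sketch (the kernel-name substring and the compiled function are purely illustrative, and the variables must be set before the kernels are compiled and launched):

    import os

    # Turn on dumping of launch tensors; optionally restrict it to kernels whose
    # generated name contains one of these substrings.
    os.environ["TORCHINDUCTOR_DUMP_LAUNCH_TENSORS"] = "1"
    os.environ["TORCHINDUCTOR_KERNELS_TO_DUMP"] = "triton_poi_fused_add"

    import torch

    @torch.compile
    def f(a, b):
        return (a + b).relu()

    f(torch.randn(1024, device="cuda"), torch.randn(1024, device="cuda"))
    # Matching kernels save their tensor arguments as tensor_<i>.pt files next to the
    # generated kernel source (or under /tmp/torchinductor_root/unhashed_kernel_inputs),
    # keeping only the last few runs per kernel as configured in torch._inductor.config.
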
diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
index 9728d27d4d79b..0ac2c79d1e98a 100644
--- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
+++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
@@ -260,7 +260,7 @@ typedef __half half;
)";
#endif
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && ROCM_VERSION < 70000
constexpr auto bfloat16_support_literal =
R"(
#ifndef __align__
@@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) {
return val;
}
+__device__ float __bfloat162float(const __nv_bfloat16 a) {
+ union
+ {
+ uint32_t int32;
+ float fp32;
+ } u = {uint32_t(a.__x) << 16};
+ return u.fp32;
+}
+#endif /* defined(__cplusplus) */
+)";
+#elif defined(USE_ROCM) && ROCM_VERSION >= 70000
+constexpr auto bfloat16_support_literal =
+ R"(
+#ifndef __align__
+#define __align__(x) __attribute__((aligned(x)))
+#endif
+
+typedef unsigned int uint32_t;
+
+typedef struct __align__(2) {
+ unsigned short x;
+}
+__nv_bfloat16_raw;
+
+#if defined(__cplusplus)
+struct __align__(2) __nv_bfloat16 {
+ __host__ __device__ __nv_bfloat16() {}
+
+ __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) {
+ __x = hr.x;
+ return *this;
+ }
+
+ unsigned short __x;
+};
+
+__device__ unsigned short __internal_float2bfloat16(
+ const float f,
+ unsigned int& sign,
+ unsigned int& remainder) {
+ unsigned int x;
+
+ x = __float_as_uint(f);
+
+ if ((x & 0x7fffffffU) > 0x7f800000U) {
+ sign = 0U;
+ remainder = 0U;
+ return static_cast<unsigned short>(0x7fffU);
+ }
+ sign = x >> 31;
+ remainder = x << 16;
+ return static_cast<unsigned short>(x >> 16);
+}
+
+/* Definitions of intrinsics */
+__device__ __nv_bfloat16 __float2bfloat16(const float a) {
+ __nv_bfloat16 val;
+ __nv_bfloat16_raw r;
+ unsigned int sign;
+ unsigned int remainder;
+ r.x = __internal_float2bfloat16(a, sign, remainder);
+ if ((remainder > 0x80000000U) ||
+ ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) {
+ r.x++;
+ }
+ val = r;
+ return val;
+}
+
__device__ float __bfloat162float(const __nv_bfloat16 a) {
union
{
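
For reference, the rounding done by __internal_float2bfloat16 in both ROCm branches is plain round-to-nearest-even on the top 16 bits of the float32 pattern. A small Python sketch of the same bit manipulation (illustrative only, not part of the patch):

    import struct

    def float_to_bfloat16_bits(f: float) -> int:
        x = struct.unpack("<I", struct.pack("<f", f))[0]   # raw float32 bits
        if (x & 0x7FFFFFFF) > 0x7F800000:                   # NaN -> canonical NaN
            return 0x7FFF
        remainder = (x << 16) & 0xFFFFFFFF                  # discarded low bits
        r = x >> 16                                         # truncated bfloat16 pattern
        # Round to nearest, ties to even (same condition as the HIP literal above).
        if remainder > 0x80000000 or (remainder == 0x80000000 and (r & 1)):
            r += 1
        return r & 0xFFFF

    assert float_to_bfloat16_bits(1.0) == 0x3F80
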
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 9d6eb35c71789..07fa4ea5e1dd7 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -620,43 +620,6 @@ void initDispatchBindings(PyObject* module) {
c10::parseDispatchKey(dispatch));
});
- // Bind SafeKernelFunction class
- py::class_<c10::SafeKernelFunction>(m, "_SafeKernelFunction")
- .def(
- "call_boxed",
- [](const c10::SafeKernelFunction& self,
- c10::DispatchKeySet keyset,
- py::args args,
- const py::kwargs& kwargs) {
- const auto& op = self.opHandle();
- auto stack = torch::jit::createStackForSchema(
- op.schema(),
- std::move(args),
- kwargs,
- /*self=*/std::nullopt);
- self.callBoxed(op, keyset, &stack);
- return torch::jit::createPyObjectForStack(std::move(stack));
- })
- .def(
- "__repr__",
- [](const c10::SafeKernelFunction& self) {
- return "SafeKernelFunction(debug='" + self.debug() + "')";
- })
- .def_property_readonly(
- "op_handle", [](const c10::SafeKernelFunction& self) -> py::object {
- return py::cast(self.opHandle());
- });
-
- m.def(
- "_dispatch_get_computed_kernel_for_dispatch_key",
- [](const char* name,
- c10::DispatchKey dispatch) -> c10::SafeKernelFunction {
- auto op =
- c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
- TORCH_CHECK(op, "operator ", name, " does not exist");
- return op->getComputedKernelForDispatchKey(dispatch);
- });
-
m.def("_dispatch_find_dangling_impls", []() -> std::vector {
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py
index c3982c33315e2..d1ac7fad7480b 100644
--- a/torch/cuda/tunable.py
+++ b/torch/cuda/tunable.py
@@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None:
transA = layout[1] == "T"
dtype = dtype_dict.get(data_type)
if data_type == "tf32":
- # User must still set HIPBLASLT_ALLOW_TF32=1
torch.backends.cuda.matmul.allow_tf32 = True
else:
torch.backends.cuda.matmul.allow_tf32 = False
diff --git a/torch/library.py b/torch/library.py
index d36c181581483..372037f09dbe5 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -45,7 +45,6 @@
"register_torch_dispatch",
"register_vmap",
"get_ctx",
- "get_kernel",
"custom_op",
"triton_op",
"wrap_triton",
@@ -1476,80 +1475,6 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
return torch._library.fake_impl.global_ctx_getter()
-def get_kernel(
- op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey]
-) -> torch._C._SafeKernelFunction:
- """Returns the computed kernel for a given operator and dispatch key.
-
- This function retrieves the kernel that would be executed for a given
- operator and dispatch key combination. The returned SafeKernelFunction
- can be used to call the kernel in a boxed fashion. The intended use
- case for this function is to retrieve the original kernel for a given
- dispatch key and then register another kernel to the same dispatch key
- that calls into the original kernel for certain cases.
-
- Args:
- op: Operator name (along with the overload) or OpOverload object
- Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef.
- dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for.
- Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value.
-
- Returns:
- torch._C._SafeKernelFunction: A safe kernel function that can be used to
- call the kernel.
-
- Raises:
- RuntimeError: If the operator does not exist.
-
- Example:
- >>> # Get the CPU kernel for torch.add
- >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU")
- >>>
- >>> # You can also use DispatchKey enum
- >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU)
- >>>
- >>> # Or use an OpOverload directly
- >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU")
- >>>
- >>> # Example: Using get_kernel in a custom op with conditional dispatch
- >>> # Get the original kernel for torch.sin
- >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU")
- >>>
- >>> # If input has negative values, use original sin, otherwise return zeros
- >>> def conditional_sin_impl(dispatch_keys, x):
- >>> if (x < 0).any():
- >>> return original_sin_kernel.call_boxed(dispatch_keys, x)
- >>> else:
- >>> return torch.zeros_like(x)
- >>>
- >>> lib = torch.library.Library("aten", "IMPL")
- >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet
- >>> which needs to be the first argument to ``kernel.call_boxed``
- >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True)
- >>>
- >>> # Test the conditional behavior
- >>> x_positive = torch.tensor([1.0, 2.0])
- >>> x_mixed = torch.tensor([-1.0, 2.0])
- >>> torch.sin(x_positive)
- tensor([0., 0.])
- >>> torch.sin(x_mixed)
- tensor([-0.8415, 0.9093])
- """
- if not isinstance(op, (str, torch._ops.OpOverload)):
- raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}")
-
- if isinstance(op, torch._ops.OpOverload):
- op = op._name
-
- if isinstance(dispatch_key, str):
- try:
- dispatch_key = torch._C.DispatchKey.__members__[dispatch_key]
- except KeyError:
- raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None
-
- return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key)
-
-
_OPCHECK_DEFAULT_UTILS = (
"test_schema",
"test_autograd_registration",
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index be284429114f5..846d2b407684c 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -181,9 +181,6 @@ def tf32_off():
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
- if torch.version.hip:
- hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
- os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
@@ -192,11 +189,6 @@ def tf32_on(self, tf32_precision=1e-5):
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
- if torch.version.hip:
- if hip_allow_tf32 is not None:
- os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
- else:
- del os.environ["HIPBLASLT_ALLOW_TF32"]
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision
@@ -246,7 +238,7 @@ def tf32_enabled():
# if device is specified, it will check if device is cuda
# if dtype is specified, it will check if dtype is float32 or complex64
# tf32 and fp32 are different only when all the three checks pass
-def tf32_on_and_off(tf32_precision=1e-5, only_if=True):
+def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True):
def with_tf32_disabled(self, function_call):
with tf32_off():
function_call()
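
With only_if now keyword-only, existing positional call sites keep working unchanged, while conditional use must name the argument. A hedged sketch of decorator usage after the change (the test class and condition value are illustrative):

    from torch.testing._internal.common_cuda import tf32_on_and_off

    class MyTests:  # illustrative test class, not from the patch
        @tf32_on_and_off(0.003)                 # tolerance only, as in the test change above
        def test_default(self, device):
            ...

        @tf32_on_and_off(0.003, only_if=False)  # the condition must now be passed by keyword
        def test_conditional(self, device):
            ...
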
diff --git a/version.txt b/version.txt
index 03e905f0db5fe..c8e38b614057b 100644
--- a/version.txt
+++ b/version.txt
@@ -1 +1 @@
-2.9.0a0
+2.9.0