diff --git a/.automation_scripts/parse_xml_results.py b/.automation_scripts/parse_xml_results.py new file mode 100644 index 0000000000000..7db2e1ce9233c --- /dev/null +++ b/.automation_scripts/parse_xml_results.py @@ -0,0 +1,178 @@ +""" The Python PyTorch testing script. +## +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" + +import xml.etree.ElementTree as ET +from pathlib import Path +from typing import Any, Dict, Tuple + +# Backends list +BACKENDS_LIST = [ + "dist-gloo", + "dist-nccl" +] + +TARGET_WORKFLOW = "--rerun-disabled-tests" + +def get_job_id(report: Path) -> int: + # [Job id in artifacts] + # Retrieve the job id from the report path. In our GHA workflows, we append + # the job id to the end of the report name, so `report` looks like: + # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml + # and we want to get `5596745227` out of it. + try: + return int(report.parts[0].rpartition("_")[2]) + except ValueError: + return -1 + +def is_rerun_disabled_tests(root: ET.ElementTree) -> bool: + """ + Check if the test report is coming from rerun_disabled_tests workflow + """ + skipped = root.find(".//*skipped") + # Need to check against None here, if not skipped doesn't work as expected + if skipped is None: + return False + + message = skipped.attrib.get("message", "") + return TARGET_WORKFLOW in message or "num_red" in message + +def parse_xml_report( + tag: str, + report: Path, + workflow_id: int, + workflow_run_attempt: int, + work_flow_name: str +) -> Dict[Tuple[str], Dict[str, Any]]: + """Convert a test report xml file into a JSON-serializable list of test cases.""" + print(f"Parsing {tag}s for test report: {report}") + + job_id = get_job_id(report) + print(f"Found job id: {job_id}") + + test_cases: Dict[Tuple[str], Dict[str, Any]] = {} + + root = ET.parse(report) + # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops + # includes skipped messages multiple times (50 times by default). This slows down + # this script too much (O(n)) because it tries to gather all the stats. This should + # be fixed later in the way we use pytest-flakefinder. 
A zipped test report from rerun + # disabled test is only few MB, but will balloon up to a much bigger XML file after + # extracting from a dozen to few hundred MB + if is_rerun_disabled_tests(root): + return test_cases + + for test_case in root.iter(tag): + case = process_xml_element(test_case) + if tag == 'testcase': + case["workflow_id"] = workflow_id + case["workflow_run_attempt"] = workflow_run_attempt + case["job_id"] = job_id + case["work_flow_name"] = work_flow_name + + # [invoking file] + # The name of the file that the test is located in is not necessarily + # the same as the name of the file that invoked the test. + # For example, `test_jit.py` calls into multiple other test files (e.g. + # jit/test_dce.py). For sharding/test selection purposes, we want to + # record the file that invoked the test. + # + # To do this, we leverage an implementation detail of how we write out + # tests (https://bit.ly/3ajEV1M), which is that reports are created + # under a folder with the same name as the invoking file. + case_name = report.parent.name + for ind in range(len(BACKENDS_LIST)): + if BACKENDS_LIST[ind] in report.parts: + case_name = case_name + "_" + BACKENDS_LIST[ind] + break + case["invoking_file"] = case_name + test_cases[ ( case["invoking_file"], case["classname"], case["name"], case["work_flow_name"] ) ] = case + elif tag == 'testsuite': + case["work_flow_name"] = work_flow_name + case["invoking_xml"] = report.name + case["running_time_xml"] = case["time"] + case_name = report.parent.name + for ind in range(len(BACKENDS_LIST)): + if BACKENDS_LIST[ind] in report.parts: + case_name = case_name + "_" + BACKENDS_LIST[ind] + break + case["invoking_file"] = case_name + + test_cases[ ( case["invoking_file"], case["invoking_xml"], case["work_flow_name"] ) ] = case + + return test_cases + +def process_xml_element(element: ET.Element) -> Dict[str, Any]: + """Convert a test suite element into a JSON-serializable dict.""" + ret: Dict[str, Any] = {} + + # Convert attributes directly into dict elements. + # e.g. + # + # becomes: + # {"name": "test_foo", "classname": "test_bar"} + ret.update(element.attrib) + + # The XML format encodes all values as strings. Convert to ints/floats if + # possible to make aggregation possible in Rockset. + for k, v in ret.items(): + try: + ret[k] = int(v) + except ValueError: + pass + try: + ret[k] = float(v) + except ValueError: + pass + + # Convert inner and outer text into special dict elements. + # e.g. + # my_inner_text my_tail + # becomes: + # {"text": "my_inner_text", "tail": " my_tail"} + if element.text and element.text.strip(): + ret["text"] = element.text + if element.tail and element.tail.strip(): + ret["tail"] = element.tail + + # Convert child elements recursively, placing them at a key: + # e.g. + # + # hello + # world + # another + # + # becomes + # { + # "foo": [{"text": "hello"}, {"text": "world"}], + # "bar": {"text": "another"} + # } + for child in element: + if child.tag not in ret: + ret[child.tag] = process_xml_element(child) + else: + # If there are multiple tags with the same name, they should be + # coalesced into a list. 
+ if not isinstance(ret[child.tag], list): + ret[child.tag] = [ret[child.tag]] + ret[child.tag].append(process_xml_element(child)) + return ret \ No newline at end of file diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py new file mode 100644 index 0000000000000..514afd19624c3 --- /dev/null +++ b/.automation_scripts/run_pytorch_unit_tests.py @@ -0,0 +1,518 @@ +#!/usr/bin/env python3 + +""" The Python PyTorch testing script. +## +# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" + +import argparse +import os +import shutil +import subprocess +from subprocess import STDOUT, CalledProcessError + +from collections import namedtuple +from datetime import datetime +from pathlib import Path +from parse_xml_results import ( + parse_xml_report +) +from pprint import pprint +from typing import Any, Dict, List + +# unit test status list +UT_STATUS_LIST = [ + "PASSED", + "MISSED", + "SKIPPED", + "FAILED", + "XFAILED", + "ERROR" +] + +DEFAULT_CORE_TESTS = [ + "test_nn", + "test_torch", + "test_cuda", + "test_ops", + "test_unary_ufuncs", + "test_autograd", + "inductor/test_torchinductor" +] + +DISTRIBUTED_CORE_TESTS = [ + "distributed/test_c10d_common", + "distributed/test_c10d_nccl", + "distributed/test_distributed_spawn" +] + +CONSOLIDATED_LOG_FILE_NAME="pytorch_unit_tests.log" + +def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, workflow_name, path="."): + test_cases = {} + items_list = os.listdir(path) + for dir in items_list: + new_dir = path + '/' + dir + '/' + if os.path.isdir(new_dir): + for xml_report in Path(new_dir).glob("**/*.xml"): + test_cases.update( + parse_xml_report( + tag, + xml_report, + workflow_run_id, + workflow_run_attempt, + workflow_name + ) + ) + return test_cases + +def get_test_status(test_case): + # In order of priority: S=skipped, F=failure, E=error, P=pass + if "skipped" in test_case and test_case["skipped"]: + type_message = test_case["skipped"] + if type_message.__contains__('type') and type_message['type'] == "pytest.xfail": + return "XFAILED" + else: + return "SKIPPED" + elif "failure" in test_case and test_case["failure"]: + return "FAILED" + elif "error" in test_case and test_case["error"]: + return "ERROR" + else: + return "PASSED" + +def get_test_message(test_case, status=None): + if status == "SKIPPED": + return test_case["skipped"] if "skipped" in test_case else "" + elif status == "FAILED": + return 
test_case["failure"] if "failure" in test_case else "" + elif status == "ERROR": + return test_case["error"] if "error" in test_case else "" + else: + if "skipped" in test_case: + return test_case["skipped"] + elif "failure" in test_case: + return test_case["failure"] + elif "error" in test_case: + return test_case["error"] + else: + return "" + +def get_test_file_running_time(test_suite): + if test_suite.__contains__('time'): + return test_suite["time"] + return 0 + +def get_test_running_time(test_case): + if test_case.__contains__('time'): + return test_case["time"] + return "" + +def summarize_xml_files(path, workflow_name): + # statistics + TOTAL_TEST_NUM = 0 + TOTAL_PASSED_NUM = 0 + TOTAL_SKIPPED_NUM = 0 + TOTAL_XFAIL_NUM = 0 + TOTAL_FAILED_NUM = 0 + TOTAL_ERROR_NUM = 0 + TOTAL_EXECUTION_TIME = 0 + + #parse the xml files + test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path) + test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path) + test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"]) + # results dict + res = {} + res_item_list = [ "PASSED", "SKIPPED", "XFAILED", "FAILED", "ERROR" ] + test_file_items = set() + for (k,v) in list(test_suites.items()): + file_name = k[0] + if not file_name in test_file_items: + test_file_items.add(file_name) + # initialization + for item in res_item_list: + temp_item = test_file_and_status(file_name, item) + res[temp_item] = {} + temp_item_statistics = test_file_and_status(file_name, "STATISTICS") + res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0} + test_running_time = get_test_file_running_time(v) + res[temp_item_statistics]["EXECUTION_TIME"] += test_running_time + TOTAL_EXECUTION_TIME += test_running_time + else: + test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS") + test_running_time = get_test_file_running_time(v) + res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time + TOTAL_EXECUTION_TIME += test_running_time + + for (k,v) in list(test_cases.items()): + file_name = k[0] + class_name = k[1] + test_name = k[2] + combined_name = file_name + "::" + class_name + "::" + test_name + test_status = get_test_status(v) + test_running_time = get_test_running_time(v) + test_message = get_test_message(v, test_status) + test_info_value = "" + test_tuple_key_status = test_file_and_status(file_name, test_status) + test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS") + TOTAL_TEST_NUM += 1 + res[test_tuple_key_statistics]["TOTAL"] += 1 + if test_status == "PASSED": + test_info_value = str(test_running_time) + res[test_tuple_key_status][combined_name] = test_info_value + res[test_tuple_key_statistics]["PASSED"] += 1 + TOTAL_PASSED_NUM += 1 + elif test_status == "SKIPPED": + test_info_value = str(test_running_time) + res[test_tuple_key_status][combined_name] = test_info_value + res[test_tuple_key_statistics]["SKIPPED"] += 1 + TOTAL_SKIPPED_NUM += 1 + elif test_status == "XFAILED": + test_info_value = str(test_running_time) + res[test_tuple_key_status][combined_name] = test_info_value + res[test_tuple_key_statistics]["XFAILED"] += 1 + TOTAL_XFAIL_NUM += 1 + elif test_status == "FAILED": + test_info_value = test_message + res[test_tuple_key_status][combined_name] = test_info_value + res[test_tuple_key_statistics]["FAILED"] += 1 + TOTAL_FAILED_NUM += 1 + elif test_status == "ERROR": + test_info_value = test_message + 
res[test_tuple_key_status][combined_name] = test_info_value + res[test_tuple_key_statistics]["ERROR"] += 1 + TOTAL_ERROR_NUM += 1 + + # generate statistics_dict + statistics_dict = {} + statistics_dict["TOTAL"] = TOTAL_TEST_NUM + statistics_dict["PASSED"] = TOTAL_PASSED_NUM + statistics_dict["SKIPPED"] = TOTAL_SKIPPED_NUM + statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM + statistics_dict["FAILED"] = TOTAL_FAILED_NUM + statistics_dict["ERROR"] = TOTAL_ERROR_NUM + statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME + aggregate_item = workflow_name + "_aggregate" + total_item = test_file_and_status(aggregate_item, "STATISTICS") + res[total_item] = statistics_dict + + return res + +def run_command_and_capture_output(cmd): + try: + print(f"Running command '{cmd}'") + with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file: + print(f"========================================", file=output_file, flush=True) + print(f"[RUN_PYTORCH_UNIT_TESTS] Running command '{cmd}'", file=output_file, flush=True) # send to consolidated file as well + print(f"========================================", file=output_file, flush=True) + p = subprocess.run(cmd, shell=True, stdout=output_file, stderr=STDOUT, text=True) + except CalledProcessError as e: + print(f"ERROR: Cmd {cmd} failed with return code: {e.returncode}!") + +def run_entire_tests(workflow_name, test_shell_path, overall_logs_path_current_run, test_reports_src): + if os.path.exists(test_reports_src): + shutil.rmtree(test_reports_src) + + os.mkdir(test_reports_src) + copied_logs_path = "" + if workflow_name == "default": + os.environ['TEST_CONFIG'] = 'default' + copied_logs_path = overall_logs_path_current_run + "default_xml_results_entire_tests/" + elif workflow_name == "distributed": + os.environ['TEST_CONFIG'] = 'distributed' + copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_entire_tests/" + elif workflow_name == "inductor": + os.environ['TEST_CONFIG'] = 'inductor' + copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_entire_tests/" + # use test.sh for tests execution + run_command_and_capture_output(test_shell_path) + copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path) + entire_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name) + return entire_results_dict + +def run_priority_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src): + if os.path.exists(test_reports_src): + shutil.rmtree(test_reports_src) + + os.mkdir(test_reports_src) + copied_logs_path = "" + if workflow_name == "default": + os.environ['TEST_CONFIG'] = 'default' + os.environ['HIP_VISIBLE_DEVICES'] = '0' + copied_logs_path = overall_logs_path_current_run + "default_xml_results_priority_tests/" + # use run_test.py for tests execution + default_priority_test_suites = " ".join(DEFAULT_CORE_TESTS) + command = "python3 " + test_run_test_path + " --include " + default_priority_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose" + run_command_and_capture_output(command) + del os.environ['HIP_VISIBLE_DEVICES'] + elif workflow_name == "distributed": + os.environ['TEST_CONFIG'] = 'distributed' + os.environ['HIP_VISIBLE_DEVICES'] = '0,1' + copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_priority_tests/" + # use run_test.py for tests execution + distributed_priority_test_suites = " ".join(DISTRIBUTED_CORE_TESTS) + command = "python3 " + test_run_test_path + " --include " + 
distributed_priority_test_suites + " --distributed-tests --verbose" + run_command_and_capture_output(command) + del os.environ['HIP_VISIBLE_DEVICES'] + copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path) + priority_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name) + + return priority_results_dict + +def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src, selected_list): + if os.path.exists(test_reports_src): + shutil.rmtree(test_reports_src) + + os.mkdir(test_reports_src) + copied_logs_path = "" + if workflow_name == "default": + os.environ['TEST_CONFIG'] = 'default' + os.environ['HIP_VISIBLE_DEVICES'] = '0' + copied_logs_path = overall_logs_path_current_run + "default_xml_results_selected_tests/" + # use run_test.py for tests execution + default_selected_test_suites = " ".join(selected_list) + command = "python3 " + test_run_test_path + " --include " + default_selected_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose" + run_command_and_capture_output(command) + del os.environ['HIP_VISIBLE_DEVICES'] + elif workflow_name == "distributed": + os.environ['TEST_CONFIG'] = 'distributed' + os.environ['HIP_VISIBLE_DEVICES'] = '0,1' + copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_selected_tests/" + # use run_test.py for tests execution + distributed_selected_test_suites = " ".join(selected_list) + command = "python3 " + test_run_test_path + " --include " + distributed_selected_test_suites + " --distributed-tests --verbose" + run_command_and_capture_output(command) + del os.environ['HIP_VISIBLE_DEVICES'] + elif workflow_name == "inductor": + os.environ['TEST_CONFIG'] = 'inductor' + copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_selected_tests/" + inductor_selected_test_suites = "" + non_inductor_selected_test_suites = "" + for item in selected_list: + if "inductor/" in item: + inductor_selected_test_suites += item + inductor_selected_test_suites += " " + else: + non_inductor_selected_test_suites += item + non_inductor_selected_test_suites += " " + if inductor_selected_test_suites != "": + inductor_selected_test_suites = inductor_selected_test_suites[:-1] + command = "python3 " + test_run_test_path + " --include " + inductor_selected_test_suites + " --verbose" + run_command_and_capture_output(command) + if non_inductor_selected_test_suites != "": + non_inductor_selected_test_suites = non_inductor_selected_test_suites[:-1] + command = "python3 " + test_run_test_path + " --inductor --include " + non_inductor_selected_test_suites + " --verbose" + run_command_and_capture_output(command) + copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path) + selected_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name) + + return selected_results_dict + +def run_test_and_summarize_results( + pytorch_root_dir: str, + priority_tests: bool, + test_config: List[str], + default_list: List[str], + distributed_list: List[str], + inductor_list: List[str], + skip_rerun: bool) -> Dict[str, Any]: + + # copy current environment variables + _environ = dict(os.environ) + + # modify path + test_shell_path = pytorch_root_dir + "/.ci/pytorch/test.sh" + test_run_test_path = pytorch_root_dir + "/test/run_test.py" + repo_test_log_folder_path = pytorch_root_dir + "/.automation_logs/" + test_reports_src = pytorch_root_dir + "/test/test-reports/" + run_test_python_file = pytorch_root_dir + 
"/test/run_test.py" + + # change directory to pytorch root + os.chdir(pytorch_root_dir) + + # all test results dict + res_all_tests_dict = {} + + # patterns + search_text = "--reruns=2" + replace_text = "--reruns=0" + + # create logs folder + if not os.path.exists(repo_test_log_folder_path): + os.mkdir(repo_test_log_folder_path) + + # Set common environment variables for all scenarios + os.environ['CI'] = '1' + os.environ['PYTORCH_TEST_WITH_ROCM'] = '1' + os.environ['HSA_FORCE_FINE_GRAIN_PCIE'] = '1' + os.environ['PYTORCH_TESTING_DEVICE_ONLY_FOR'] = 'cuda' + os.environ['CONTINUE_THROUGH_ERROR'] = 'True' + if skip_rerun: + # modify run_test.py in-place + with open(run_test_python_file, 'r') as file: + data = file.read() + data = data.replace(search_text, replace_text) + with open(run_test_python_file, 'w') as file: + file.write(data) + + # Time stamp + current_datetime = datetime.now().strftime("%Y%m%d_%H-%M-%S") + print("Current date & time : ", current_datetime) + # performed as Job ID + str_current_datetime = str(current_datetime) + overall_logs_path_current_run = repo_test_log_folder_path + str_current_datetime + "/" + os.mkdir(overall_logs_path_current_run) + + global CONSOLIDATED_LOG_FILE_PATH + CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME + + # Check multi gpu availability if distributed tests are enabled + if ("distributed" in test_config) or len(distributed_list) != 0: + check_num_gpus_for_distributed() + + # Install test requirements + command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt" + run_command_and_capture_output(command) + + # Run entire tests for each workflow + if not priority_tests and not default_list and not distributed_list and not inductor_list: + # run entire tests for default, distributed and inductor workflows → use test.sh + if not test_config: + check_num_gpus_for_distributed() + # default test process + res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["default"] = res_default_all + # distributed test process + res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["distributed"] = res_distributed_all + # inductor test process + res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["inductor"] = res_inductor_all + else: + workflow_list = [] + for item in test_config: + workflow_list.append(item) + if "default" in workflow_list: + res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["default"] = res_default_all + if "distributed" in workflow_list: + res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["distributed"] = res_distributed_all + if "inductor" in workflow_list: + res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["inductor"] = res_inductor_all + # Run priority test for each workflow + elif priority_tests and not default_list and not distributed_list and not inductor_list: + if not test_config: + check_num_gpus_for_distributed() + # default test process + res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src) + 
res_all_tests_dict["default"] = res_default_priority + # distributed test process + res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["distributed"] = res_distributed_priority + # will not run inductor priority tests + print("Inductor priority tests cannot run since no core tests defined with inductor workflow.") + else: + workflow_list = [] + for item in test_config: + workflow_list.append(item) + if "default" in workflow_list: + res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["default"] = res_default_priority + if "distributed" in workflow_list: + res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src) + res_all_tests_dict["distributed"] = res_distributed_priority + if "inductor" in workflow_list: + print("Inductor priority tests cannot run since no core tests defined with inductor workflow.") + # Run specified tests for each workflow + elif (default_list or distributed_list or inductor_list) and not test_config and not priority_tests: + if default_list: + default_workflow_list = [] + for item in default_list: + default_workflow_list.append(item) + res_default_selected = run_selected_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src, default_workflow_list) + res_all_tests_dict["default"] = res_default_selected + if distributed_list: + distributed_workflow_list = [] + for item in distributed_list: + distributed_workflow_list.append(item) + res_distributed_selected = run_selected_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src, distributed_workflow_list) + res_all_tests_dict["distributed"] = res_distributed_selected + if inductor_list: + inductor_workflow_list = [] + for item in inductor_list: + inductor_workflow_list.append(item) + res_inductor_selected = run_selected_tests("inductor", test_run_test_path, overall_logs_path_current_run, test_reports_src, inductor_workflow_list) + res_all_tests_dict["inductor"] = res_inductor_selected + else: + raise Exception("Invalid test configurations!") + + # restore environment variables + os.environ.clear() + os.environ.update(_environ) + + # restore files + if skip_rerun: + # modify run_test.py in-place + with open(run_test_python_file, 'r') as file: + data = file.read() + data = data.replace(replace_text, search_text) + with open(run_test_python_file, 'w') as file: + file.write(data) + + return res_all_tests_dict + +def parse_args(): + parser = argparse.ArgumentParser(description='Run PyTorch unit tests and generate xml results summary', formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument('--test_config', nargs='+', default=[], type=str, help="space-separated list of test workflows to be executed eg. 'default distributed'") + parser.add_argument('--priority_tests', action='store_true', help="run priority tests only") + parser.add_argument('--default_list', nargs='+', default=[], help="space-separated list of 'default' config test suites/files to be executed eg. 'test_weak test_dlpack'") + parser.add_argument('--distributed_list', nargs='+', default=[], help="space-separated list of 'distributed' config test suites/files to be executed eg. 
'distributed/test_c10d_common distributed/test_c10d_nccl'") + parser.add_argument('--inductor_list', nargs='+', default=[], help="space-separated list of 'inductor' config test suites/files to be executed eg. 'inductor/test_torchinductor test_ops'") + parser.add_argument('--pytorch_root', default='.', type=str, help="PyTorch root directory") + parser.add_argument('--skip_rerun', action='store_true', help="skip rerun process") + parser.add_argument('--example_output', type=str, help="{'workflow_name': {\n" + " test_file_and_status(file_name='workflow_aggregate', status='STATISTICS'): {}, \n" + " test_file_and_status(file_name='test_file_name_1', status='ERROR'): {}, \n" + " test_file_and_status(file_name='test_file_name_1', status='FAILED'): {}, \n" + " test_file_and_status(file_name='test_file_name_1', status='PASSED'): {}, \n" + " test_file_and_status(file_name='test_file_name_1', status='SKIPPED'): {}, \n" + " test_file_and_status(file_name='test_file_name_1', status='STATISTICS'): {} \n" + "}}\n") + parser.add_argument('--example_usages', type=str, help="RUN ALL TESTS: python3 run_pytorch_unit_tests.py \n" + "RUN PRIORITY TESTS: python3 run_pytorch_unit_tests.py --test_config distributed --priority_test \n" + "RUN SELECTED TESTS: python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack --inductor_list inductor/test_torchinductor") + return parser.parse_args() + +def check_num_gpus_for_distributed(): + p = subprocess.run("rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True) + num_gpus_visible = int(p.stdout) + assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests" + +def main(): + args = parse_args() + all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun) + pprint(dict(all_tests_results)) + +if __name__ == "__main__": + main() diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 71d3ef714fbe1..a1e9df4725c57 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bbb06c0334a6772b92d24bde54956e675c8c6604 +d704bc6e69c1a588c8edd3cbb67505d554ed65f6 diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index f48140952c3ac..8e714bcb6cd32 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -21,7 +21,7 @@ elif [ -n "${TRITON_CPU}" ]; then TRITON_REPO="https://github.com/triton-lang/triton-cpu" TRITON_TEXT_FILE="triton-cpu" else - TRITON_REPO="https://github.com/triton-lang/triton" + TRITON_REPO="https://github.com/ROCm/triton" TRITON_TEXT_FILE="triton" fi diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index 583136d7df2f1..d3da8e93c639a 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -10,7 +10,7 @@ boto3==1.35.42 #Pinned versions: 1.19.12, 1.16.34 #test that import: -click +click==8.3.0 #Description: Command Line Interface Creation Kit #Pinned versions: #test that import: @@ -63,7 +63,7 @@ lark==0.12.0 #Pinned versions: 0.12.0 #test that import: -librosa>=0.6.2 ; python_version < "3.11" and platform_machine != "s390x" +librosa==0.11.0 ; python_version < "3.11" and platform_machine != "s390x" librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x" #Description: A python package for music and audio analysis #Pinned versions: >=0.6.2 @@ 
-113,9 +113,8 @@ ninja==1.11.1.3 #test that import: run_test.py, test_cpp_extensions_aot.py,test_determination.py numba==0.49.0 ; python_version < "3.9" and platform_machine != "s390x" -numba==0.55.2 ; python_version == "3.9" and platform_machine != "s390x" -numba==0.55.2 ; python_version == "3.10" and platform_machine != "s390x" -numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x" +numba==0.60.0 ; python_version == "3.9" and platform_machine != "s390x" +numba==0.61.2 ; python_version > "3.9" and platform_machine != "s390x" #Description: Just-In-Time Compiler for Numerical Functions #Pinned versions: 0.54.1, 0.49.0, <=0.49.1 #test that import: test_numba_integration.py @@ -134,12 +133,10 @@ numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x" #test_nn.py, test_namedtensor.py, test_linalg.py, test_jit_cuda_fuser.py, #test_jit.py, test_indexing.py, test_datapipe.py, test_dataloader.py, #test_binary_ufuncs.py -numpy==1.22.4; python_version == "3.9" or python_version == "3.10" -numpy==1.26.2; python_version == "3.11" or python_version == "3.12" -numpy==2.1.2; python_version >= "3.13" +numpy==2.0.2 ; python_version == "3.9" +numpy==2.1.2 ; python_version > "3.9" -pandas==2.0.3; python_version < "3.13" -pandas==2.2.3; python_version >= "3.13" +pandas==2.2.3 #onnxruntime #Description: scoring engine for Open Neural Network Exchange (ONNX) models @@ -169,10 +166,11 @@ pillow==11.0.0 #Pinned versions: 10.3.0 #test that import: -protobuf==5.29.4 -#Description: Google's data interchange format -#Pinned versions: 5.29.4 -#test that import: test_tensorboard.py, test/onnx/* +protobuf==3.20.2 ; python_version <= "3.12" +protobuf==4.25.1 ; python_version == "3.13" +#Description: Google’s data interchange format +#Pinned versions: 3.20.1 +#test that import: test_tensorboard.py psutil #Description: information on running processes and system utilization @@ -194,7 +192,7 @@ pytest-flakefinder==1.1.0 #Pinned versions: 1.1.0 #test that import: -pytest-rerunfailures>=10.3 +pytest-rerunfailures==14.0 #Description: plugin for rerunning failure tests in pytest #Pinned versions: #test that import: @@ -250,8 +248,8 @@ scikit-image==0.22.0 ; python_version >= "3.10" #Pinned versions: 0.20.3 #test that import: -scipy==1.10.1 ; python_version <= "3.11" -scipy==1.14.1 ; python_version >= "3.12" +scipy==1.13.1 ; python_version == "3.9" +scipy==1.14.1 ; python_version > "3.9" # Pin SciPy because of failing distribution tests (see #60347) #Description: scientific python #Pinned versions: 1.10.1 @@ -265,7 +263,7 @@ scipy==1.14.1 ; python_version >= "3.12" #test that import: # needed by torchgen utils -typing-extensions>=4.10.0 +typing_extensions==4.15.0 #Description: type hints for python #Pinned versions: #test that import: @@ -275,7 +273,7 @@ typing-extensions>=4.10.0 #Pinned versions: #test that import: -unittest-xml-reporting<=3.2.0,>=2.0.0 +unittest-xml-reporting==3.2.0 #Description: saves unit test results to xml #Pinned versions: #test that import: @@ -286,7 +284,7 @@ lintrunner==0.12.7 #Pinned versions: 0.12.7 #test that import: -redis>=4.0.0 +redis==6.4.0 #Description: redis database #test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py) @@ -310,8 +308,7 @@ z3-solver==4.15.1.0 ; platform_machine != "s390x" #Pinned versions: #test that import: -tensorboard==2.13.0 ; python_version < "3.13" -tensorboard==2.18.0 ; python_version >= "3.13" +tensorboard==2.18.0 #Description: Also included in 
.ci/docker/requirements-docs.txt #Pinned versions: #test that import: test_tensorboard @@ -323,7 +320,8 @@ pywavelets==1.7.0 ; python_version >= "3.12" #Pinned versions: 1.4.1 #test that import: -lxml==5.3.0 +lxml==5.3.0 ; python_version <= "3.12" +lxml==6.0.0 ; python_version == "3.13" #Description: This is a requirement of unittest-xml-reporting # Python-3.9 binaries @@ -335,8 +333,9 @@ sympy==1.13.3 #Pinned versions: #test that import: -onnx==1.18.0 -#Description: Required by onnx tests, and mypy and test_public_bindings.py when checking torch.onnx._internal +onnx==1.16.1 ; python_version <= "3.12" +onnx==1.18.0 ; python_version == "3.13" +#Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: @@ -360,10 +359,10 @@ pwlf==2.2.1 #test that import: test_sac_estimator.py # To build PyTorch itself -pyyaml -pyzstd -setuptools>=70.1.0 -six +PyYAML==6.0.3 +pyzstd==0.18.0 +setuptools==79.0.1 +six==1.17.0 scons==4.5.2 ; platform_machine == "aarch64" diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index f12a3ac075175..aa82d36aa7ce6 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -5,7 +5,9 @@ export TZ=UTC tagged_version() { GIT_DIR="${workdir}/pytorch/.git" GIT_DESCRIBE="git --git-dir ${GIT_DIR} describe --tags --match v[0-9]*.[0-9]*.[0-9]*" - if [[ ! -d "${GIT_DIR}" ]]; then + if [[ -n "${CIRCLE_TAG:-}" ]]; then + echo "${CIRCLE_TAG}" + elif [[ ! -d "${GIT_DIR}" ]]; then echo "Abort, abort! Git dir ${GIT_DIR} does not exists!" kill $$ elif ${GIT_DESCRIBE} --exact >/dev/null; then @@ -69,6 +71,8 @@ fi export PYTORCH_BUILD_NUMBER=1 +# This part is done in the builder scripts so commenting the duplicate code +: <<'BLOCK_COMMENT' # Set triton version as part of PYTORCH_EXTRA_INSTALL_REQUIREMENTS TRITON_VERSION=$(cat $PYTORCH_ROOT/.ci/docker/triton_version.txt) TRITON_CONSTRAINT="platform_system == 'Linux'" @@ -110,6 +114,7 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS} | ${TRITON_REQUIREMENT}" fi fi +BLOCK_COMMENT USE_GLOO_WITH_OPENSSL="ON" if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 11fa8404273d3..f2851e3317256 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os +import re import shutil import sys from pathlib import Path @@ -50,6 +51,30 @@ def patch_init_py( with open(path, "w") as f: f.write(orig) +def get_rocm_version() -> str: + rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm" + rocm_version = "0.0.0" + rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h" + if not os.path.isfile(rocm_version_h): + rocm_version_h = f"{rocm_path}/include/rocm_version.h" + # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install. 
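For context on `get_rocm_version`: `rocm_version.h` exposes the ROCm version as three C macros (`ROCM_VERSION_MAJOR`/`MINOR`/`PATCH`), which the regexes that follow extract. A minimal standalone sketch of that parse — the macro values below are illustrative, not taken from this patch:

```python
import re

# Hypothetical contents of <ROCM_PATH>/include/rocm-core/rocm_version.h
sample_header = """
#define ROCM_VERSION_MAJOR 6
#define ROCM_VERSION_MINOR 2
#define ROCM_VERSION_PATCH 4
"""

parts = []
for field in ("MAJOR", "MINOR", "PATCH"):
    match = re.search(rf"#define\s+ROCM_VERSION_{field}\s+(\d+)", sample_header)
    parts.append(match.group(1) if match else "0")

print(".".join(parts))  # -> "6.2.4"
```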
+ if os.path.isfile(rocm_version_h): + RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)") + RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)") + RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)") + major, minor, patch = 0, 0, 0 + for line in open(rocm_version_h): + match = RE_MAJOR.search(line) + if match: + major = int(match.group(1)) + match = RE_MINOR.search(line) + if match: + minor = int(match.group(1)) + match = RE_PATCH.search(line) + if match: + patch = int(match.group(1)) + rocm_version = str(major)+"."+str(minor)+"."+str(patch) + return rocm_version def build_triton( *, @@ -64,14 +89,24 @@ def build_triton( if "MAX_JOBS" not in env: max_jobs = os.cpu_count() or 1 env["MAX_JOBS"] = str(max_jobs) - + if not release: + # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8 + # while release build should only include the version, i.e. 2.1.0 + rocm_version = get_rocm_version() + version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}" + version += version_suffix with TemporaryDirectory() as tmpdir: triton_basedir = Path(tmpdir) / "triton" triton_pythondir = triton_basedir / "python" triton_repo = "https://github.com/openai/triton" if device == "rocm": - triton_pkg_name = "pytorch-triton-rocm" + triton_repo = "https://github.com/ROCm/triton" + rocm_version = get_rocm_version() # e.g., "7.0.1" + if tuple(map(int, rocm_version.split("."))) > (7, 0, 0): + triton_pkg_name = "triton" + else: + triton_pkg_name = "pytorch-triton-rocm" elif device == "xpu": triton_pkg_name = "pytorch-triton-xpu" triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton" @@ -89,6 +124,7 @@ def build_triton( # change built wheel name and version env["TRITON_WHEEL_NAME"] = triton_pkg_name + env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix if with_clang_ldd: env["TRITON_BUILD_WITH_CLANG_LLD"] = "1" diff --git a/CMakeLists.txt b/CMakeLists.txt index ce7890f002d3b..91181735750d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -873,7 +873,7 @@ cmake_dependent_option( "Whether to build the flash_attention kernel for scaled dot product attention.\ Will be disabled if not supported by the platform" ON - "USE_CUDA OR USE_ROCM;NOT MSVC" + "(USE_CUDA AND NOT MSVC) OR USE_ROCM" OFF) cmake_dependent_option( @@ -908,7 +908,7 @@ cmake_dependent_option( # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake # if(USE_ROCM) - if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) + if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION) include(cmake/External/aotriton.cmake) endif() endif() diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6c095680733fe..b30d8336e8ec9 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -301,13 +301,14 @@ IF(USE_FBGEMM_GENAI) # Add additional HIPCC compiler flags for performance set(FBGEMM_GENAI_EXTRA_HIPCC_FLAGS - -mllvm - -amdgpu-coerce-illegal-types=1 -mllvm -enable-post-misched=0 -mllvm -greedy-reverse-local-assignment=1 -fhip-new-launch-api) + if(DEFINED ROCM_VERSION_DEV AND ROCM_VERSION_DEV VERSION_LESS "7.2.0") + list(PREPEND FBGEMM_GENAI_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) + endif() hip_add_library( fbgemm_genai STATIC diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index 4d48084b0ab89..7a8d02be530e3 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -180,7 +180,7 @@ void Context::setUserEnabledNNPACK(bool e) { } bool Context::allowTF32CuDNN(const std::string& op) const { - if 
(op.size() == 0){ + if (op.empty()){ bool allow_tf32_rnn = float32Precision("cuda", "rnn") == "tf32"; bool allow_tf32_conv = float32Precision("cuda", "conv") == "tf32"; TORCH_CHECK( @@ -281,9 +281,6 @@ bool Context::userEnabledOverrideableSDP() const { static constexpr const auto cublas_config_var_name = "CUBLAS_WORKSPACE_CONFIG"; static constexpr const std::array cublas_deterministic_configs = {":4096:8", ":16:8"}; -#ifdef USE_ROCM -static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32"; -#endif bool Context::checkCuBLASConfigDeterministic() { // If using CUDA 10.2 or greater, need to make sure CuBLAS workspace config @@ -343,12 +340,6 @@ void Context::setImmediateMiopen(bool b) { } bool Context::allowTF32CuBLAS() const { -#ifdef USE_ROCM - const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32); - if (allow_tf32 != true) { - return false; - } -#endif bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; bool allow_tf32_new = float32Precision("cuda", "matmul") == "tf32"; TORCH_CHECK( @@ -362,14 +353,6 @@ bool Context::allowTF32CuBLAS() const { } void Context::setAllowTF32CuBLAS(bool b) { -#ifdef USE_ROCM - const auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32); - if (allow_tf32 != true) { - C10_LOG_FIRST_N(INFO, 10) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. " - << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it."; - return; - } -#endif float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST; setFloat32Precision("cuda", "matmul", b ? "tf32" : "ieee"); } @@ -443,7 +426,7 @@ void Context::setFloat32Precision(const std::string& backend, const std::string& std::string msg; auto iterp = _fp32_precisions.find(backend); TORCH_CHECK(iterp != _fp32_precisions.end()); - for (auto p : iterp->second) { + for (const auto& p : iterp->second) { msg += p; msg += " "; } diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h index 4300217235b84..06bcc5d4f49b8 100644 --- a/aten/src/ATen/core/boxing/KernelFunction.h +++ b/aten/src/ATen/core/boxing/KernelFunction.h @@ -6,8 +6,6 @@ #include #include #include -#include -#include #include namespace c10 { @@ -19,9 +17,6 @@ class OperatorHandle; struct OperatorKernel; class KernelFunction; -class KernelToken; -class SafeKernelFunction; - template using has_symint = std::disjunction< std::is_same, @@ -95,12 +90,6 @@ class TORCH_API KernelFunction final { BoxedKernel::BoxedKernelFunction_withDispatchKeys; KernelFunction(); - ~KernelFunction(); - - KernelFunction(const KernelFunction& other); - KernelFunction& operator=(const KernelFunction& other); - - KernelFunction(KernelFunction&&) noexcept = default; // Fast path for dispatch to allow not touching the boxed kernel in // the common case where unboxed is available. 
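Stepping back to the Context.cpp hunks above: removing the `HIPBLASLT_ALLOW_TF32` gate means the regular TF32 switches now take effect on ROCm without an environment variable. A minimal usage sketch of the intended behavior (not a test included in this patch):

```python
import torch

# With the env-var gate removed, the ordinary backend flag is enough on ROCm;
# HIPBLASLT_ALLOW_TF32=1 is no longer required.
torch.backends.cuda.matmul.allow_tf32 = True
print(torch.get_float32_matmul_precision())  # -> "high"
```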
@@ -273,9 +262,6 @@ class TORCH_API KernelFunction final { // For testing internal invariants only bool _equalsBoxedAndUnboxed(const KernelFunction&) const; - // Register a token to be invalidated when this KernelFunction is destroyed - void registerToken(std::weak_ptr token) const; - private: explicit KernelFunction( std::unique_ptr functor, @@ -290,50 +276,6 @@ class TORCH_API KernelFunction final { BoxedKernel boxed_kernel_func_; void* unboxed_kernel_func_; void* sym_unboxed_kernel_func_; - // List of tokens that need to be invalidated when this KernelFunction is - // destroyed (lazy allocation to save memory when empty) - mutable std::unique_ptr>> tokens_; -}; - -// Token held by SafeKernelFunction that gets invalidated when KernelFunction is -// destroyed -class KernelToken { - public: - bool isValid() const; - void invalidate(); - - private: - std::atomic invalid_{false}; -}; - -class SafeKernelFunction { - public: - SafeKernelFunction( - const KernelFunction* kernel, - std::string debug, - std::shared_ptr opHandle); - - // Safe callBoxed - checks token validity first - void callBoxed( - const OperatorHandle& opHandle, - DispatchKeySet dispatchKeySet, - Stack* stack) const; - - // Get debug information - const std::string& debug() const { - return debug_; - } - - // Get the OpHandle that lives on this SafeKernelFunction - const OperatorHandle& opHandle() const { - return *opHandle_; - } - - private: - KernelFunction kernel_; - std::shared_ptr token_; - std::string debug_; - std::shared_ptr opHandle_; }; } // namespace c10 diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index 672309ec19a2c..a89a0e8952b6e 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -24,36 +24,6 @@ inline KernelFunction::KernelFunction() unboxed_kernel_func_(nullptr), sym_unboxed_kernel_func_(nullptr) {} -inline KernelFunction::~KernelFunction() { - if (tokens_) { - for (auto& weak_token : *tokens_) { - if (auto token = weak_token.lock()) { - token->invalidate(); - } - } - } -} - -inline KernelFunction::KernelFunction(const KernelFunction& other) - : boxed_kernel_func_(other.boxed_kernel_func_), - unboxed_kernel_func_(other.unboxed_kernel_func_), - sym_unboxed_kernel_func_(other.sym_unboxed_kernel_func_) { - // tokens_ is intentionally not copied as we only care about invalidating - // tokens if the original KernelFunction is destroyed -} - -inline KernelFunction& KernelFunction::operator=(const KernelFunction& other) { - if (this != &other) { - boxed_kernel_func_ = other.boxed_kernel_func_; - unboxed_kernel_func_ = other.unboxed_kernel_func_; - sym_unboxed_kernel_func_ = other.sym_unboxed_kernel_func_; - - // tokens_ is intentionally not copied as we only care about invalidating - // tokens if the original KernelFunction is destroyed - } - return *this; -} - inline KernelFunction::KernelFunction( std::unique_ptr functor, InternalBoxedKernelFunction* boxed_kernel_func, @@ -187,14 +157,6 @@ C10_ALWAYS_INLINE Return KernelFunction::call( std::forward(args)...); } -inline void KernelFunction::registerToken( - std::weak_ptr token) const { - if (!tokens_) { - tokens_ = std::make_unique>>(); - } - tokens_->push_back(std::move(token)); -} - inline KernelFunction KernelFunction::makeFromBoxedKernel( BoxedKernel boxed_fn) { return KernelFunction( @@ -355,38 +317,4 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) { std::forward(lambda))); } -inline bool KernelToken::isValid() const { - 
return !invalid_.load(std::memory_order_acquire); -} - -inline void KernelToken::invalidate() { - invalid_.store(true, std::memory_order_release); -} - -inline SafeKernelFunction::SafeKernelFunction( - const KernelFunction* kernel, - std::string debug, - std::shared_ptr opHandle) - : kernel_(kernel ? *kernel : KernelFunction()), - token_(std::make_shared()), - debug_(std::move(debug)), - opHandle_(std::move(opHandle)) { - // Register the token with the original kernel so it gets invalidated when the - // kernel is destroyed - if (kernel) { - kernel->registerToken(token_); - } -} - -inline void SafeKernelFunction::callBoxed( - const OperatorHandle& opHandle, - DispatchKeySet dispatchKeySet, - Stack* stack) const { - TORCH_CHECK( - token_ && token_->isValid(), - "SafeKernelFunction has been invalidated ", - debug_); - kernel_.callBoxed(opHandle, dispatchKeySet, stack); -} - } // namespace c10 diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 43eb0028c70fe..bc043df6a93e9 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -487,10 +487,6 @@ class TORCH_API OperatorHandle { return operatorDef_->op.hasComputedKernelForDispatchKey(k); } - SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const { - return operatorDef_->op.getComputedKernelForDispatchKey(k); - } - std::string dumpComputedTable() const { return operatorDef_->op.dumpComputedTable(); } diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index c172e9b9c6096..b4063fb720be0 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -315,42 +315,6 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat return nullptr; } -SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey( - DispatchKey k) const { - TORCH_CHECK( - !isAliasDispatchKey(k), - "Alias keys do not have runtime kernel registrations."); - const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k); - TORCH_CHECK( - dispatchTable_[dispatch_ix].isValid(), - "no kernel for ", - k, - " for ", - name_); - - // Get the KernelFunction object from kernels_ to pass to SafeKernelFunction - - // The KernelFunction object in dispatchTable_ is a copy of the KernelFunction - // in the AnnotatedKernel in kernels_. A KernelFunction is only truly - // deregistered when the kernel is removed from kernels_. However, the - // KernelFunction in dispatchTable_ might be removed before it is deregistered - // (when a newer kernel is registered). Therefore, here we want to return a - // SafeKernelFunction that is backed by the original KernelFunction in - // kernels_, so that we only invalidate it when the kernel is deregistered. 
- auto [annotatedKernel, _] = - computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k); - - // Use findSchemaOrThrow to get OpHandle for the OperatorEntry - auto& dispatcher = c10::Dispatcher::singleton(); - auto opHandle = dispatcher.findSchemaOrThrow( - name_.name.c_str(), name_.overload_name.c_str()); - - return SafeKernelFunction( - &annotatedKernel.kernel, - annotatedKernel.debug, - std::make_shared(opHandle)); -} - const std::vector& OperatorEntry::getTags() const { #if defined C10_MOBILE TORCH_CHECK(false, "tags are not saved for Mobile"); diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 59b54ce1d9d32..83200ff9c94ff 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -217,8 +217,6 @@ class TORCH_API OperatorEntry final { const KernelFunction& kernelForDispatchKey(DispatchKey k) const; // Returns true if the "computed table" has an entry for a particular key. bool hasComputedKernelForDispatchKey(DispatchKey k) const; - // Returns a KernelFunction corresponding to the kernel in dispatchTable - SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const; // Returns all the operator tags added at the time of registration const std::vector& getTags() const; void setReportErrorCallback_(std::unique_ptr callback); diff --git a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh index 60e1a19c1aacf..a65db3f2df12a 100644 --- a/aten/src/ATen/cuda/detail/OffsetCalculator.cuh +++ b/aten/src/ATen/cuda/detail/OffsetCalculator.cuh @@ -45,6 +45,24 @@ struct OffsetCalculator { C10_HOST_DEVICE offset_type get(index_t linear_idx) const { offset_type offsets; + +#if defined(USE_ROCM) + if ((dims > 0) && (dims <= 2)) { + auto divmod = sizes_[0].divmod(linear_idx); + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) + offsets[arg] = divmod.mod * strides_[0][arg]; + if (dims >= 2) { + divmod = sizes_[1].divmod(divmod.div); + #pragma unroll + for (int arg = 0; arg < NARGS; arg++) + offsets[arg] += divmod.mod * strides_[1][arg]; + } + // [...] 
+ return offsets; + } +#endif + #pragma unroll for (int arg = 0; arg < NARGS; arg++) { offsets[arg] = 0; diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 12ad84a15b180..ee28c5c1693f4 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { dtypes[i] = iter.dtype(i); } auto offset_calc = ::make_offset_calculator(iter); +#ifdef USE_ROCM + constexpr int grp_sz = 128; + launch_legacy_kernel_manual_unroll(numel, [=] GPU_LAMBDA(int idx, bool unrl) { + if (unrl) { + auto offsets0 = offset_calc.get(idx); + auto offsets1 = offset_calc.get(idx + grp_sz); + auto offsets2 = offset_calc.get(idx + grp_sz * 2); + auto offsets3 = offset_calc.get(idx + grp_sz * 3); + void* out0 = data[0] + offsets0[0]; + void* out1 = data[0] + offsets1[0]; + void* out2 = data[0] + offsets2[0]; + void* out3 = data[0] + offsets3[0]; + arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1); + arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1); + arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1); + arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out0, result0); + c10::cast_and_store(dtypes[0], out1, result1); + c10::cast_and_store(dtypes[0], out2, result2); + c10::cast_and_store(dtypes[0], out3, result3); + } else { + auto offsets = offset_calc.get(idx); + void* out = data[0] + offsets[0]; + arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); + c10::cast_and_store(dtypes[0], out, result); + } + }); +#else launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) { auto offsets = offset_calc.get(idx); void* out = data[0] + offsets[0]; arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1); c10::cast_and_store(dtypes[0], out, result); }); +#endif } } diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 59b0426bab1f0..62a07e1e28c86 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) { }); } +#ifdef USE_ROCM +void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) { + gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) { + return static_cast(value); + }); +} +void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) { + gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) { + return static_cast(value); + }); +} +#endif + void float8_copy_kernel_cuda(TensorIteratorBase &iter) { ScalarType dtype = iter.dtype(0); ScalarType other_dtype = iter.dtype(1); @@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) { } else { float16_copy_kernel_cuda(iter); } - } else if (isBitsType(dtype)) { + } +#ifdef USE_ROCM + else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) { + if (iter.dtype(1) == kBFloat16) { + bfloat16tofloat32_copy_kernel_cuda(iter); + } else { + float16tofloat32_copy_kernel_cuda(iter); + } + } +#endif + else if (isBitsType(dtype)) { TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting " "bits types to different bits types. 
Source dtype is ", iter.dtype(1), "target dtype is ", dtype); AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] { diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 02feb55cb69d6..dacef18c79b68 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -59,7 +59,7 @@ constexpr uint64_t getDefaultMaxThreadsPerBlock() { #ifdef USE_ROCM #define SKIP_SORTED_INDICES 32 template -__global__ void indexing_backward_kernel( +__global__ void indexing_backward_kernel_many_indices( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) { using opmath_t = at::opmath_type; @@ -254,7 +254,8 @@ __global__ void indexing_backward_kernel_stride_1( } } } -#else +#endif + template __global__ void indexing_backward_kernel( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -333,6 +334,7 @@ __global__ void indexing_backward_kernel( } } +#ifndef USE_ROCM template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -708,6 +710,9 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size))) : grid.y, + grid.z); dim3 new_grid(ceil_div(num_indices, (int64_t) (indices_per_block * warp_size)), grid.y, grid.z); size_t smem_dups_size = indices_per_block * warp_size * sizeof(int64_t); #define KERNEL_GRID new_grid @@ -780,11 +785,43 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List= 200000) + AT_DISPATCH_V2( + expandedValue.scalar_type(), + "indexing_backward_many_indices", + AT_WRAP([&] { + indexing_backward_kernel_many_indices<<>>( + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore, + accumulate); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + // AT_EXPAND(AT_FLOAT8_TYPES), + // TODO(#113663): clean up accumulation behavior in float8 dtypes, accumulate=True + // should not be supported here, then reenable AT_FLOAT8_DTYPES + kFloat8_e4m3fn, + kFloat8_e5m2, + kFloat8_e4m3fnuz, + kFloat8_e5m2fnuz, + kComplexHalf, + kHalf, + kBool, + kBFloat16); + else +#endif AT_DISPATCH_V2( expandedValue.scalar_type(), "indexing_backward", AT_WRAP([&] { - indexing_backward_kernel<<>>( + indexing_backward_kernel<<>>( sorted_indices.const_data_ptr(), orig_indices.const_data_ptr(), expandedValue.const_data_ptr(), diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 088aa517aa23a..0e3fc88b569c0 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -23,7 +23,7 @@ namespace at::native { // The maximum number of threads in a block #if defined(USE_ROCM) -constexpr int MAX_BLOCK_SIZE = 256; +constexpr int MAX_BLOCK_SIZE = 1024; #else constexpr int MAX_BLOCK_SIZE = 512; #endif @@ -33,7 +33,7 @@ constexpr unsigned MAX_GRID_SIZE = 65535u; // Number of threads in a block given an input size up to MAX_BLOCK_SIZE static int getNumThreads(int nElem) { #if defined(USE_ROCM) - int threadSizes[5] = { 16, 32, 64, 128, MAX_BLOCK_SIZE }; + int threadSizes[5] = { 64, 128, 256, 512, 
MAX_BLOCK_SIZE }; #else int threadSizes[5] = { 32, 64, 128, 256, MAX_BLOCK_SIZE }; #endif @@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { // first the reductions each thread does separately scalar_t sum = static_cast(0); for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) { +#if defined(USE_ROCM) + constexpr int UNRL = 4; // load deserilize factor + scalar_t tmp[UNRL]; + for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) { +#pragma unroll + for (int u = 0; u < UNRL; u++) + tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x))); +#pragma unroll + for (int u = 0; u < UNRL; u++) + if (x+u*blockDim.x < tensor.size(2)) + sum += tmp[u]; + } +#else for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) { sum += op(batch, plane, x); } +#endif } __shared__ scalar_t shared[C10_WARP_SIZE]; SumReduceOp reduce_op; @@ -292,6 +306,22 @@ __global__ void batch_norm_collect_statistics_kernel( stat_accscalar_t var_n = 0; int n = 0; for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) { +#if defined(USE_ROCM) + constexpr int UNRL = 4; + stat_accscalar_t v_[UNRL]; + for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) { + for (int u = 0; u < UNRL; u++) + v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)]; + for (int u = 0; u < UNRL; u++) { + if (x+u*blockDim.x < input.size(2)) { + stat_accscalar_t d1 = v_[u] - avg; + n++; + avg += d1 / n; + var_n += d1 * (v_[u] - avg); + } + } + } +#else for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) { stat_accscalar_t v = input[batch][plane][x]; stat_accscalar_t d1 = v - avg; @@ -299,6 +329,7 @@ __global__ void batch_norm_collect_statistics_kernel( avg += d1 / n; var_n += d1 * (v - avg); } +#endif } // first warpSum to get one value per thread to diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu index b891750891d58..b46bbaa6500b9 100644 --- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu @@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame( } } +#ifdef USE_ROCM +// Helper function to compute output pixel range that can contribute to input pixel +template +__device__ __forceinline__ void compute_output_range( + int input_pos, + accscalar_t scale, + int output_size, + bool align_corners, + int& min_output, + int& max_output) { + accscalar_t lo, hi; + if (align_corners) { + lo = static_cast(input_pos - 1) / scale; + hi = static_cast(input_pos + 1) / scale; + } else { + lo = (input_pos - static_cast(0.5)) / scale - static_cast(0.5); + hi = (input_pos + static_cast(1.5)) / scale - static_cast(0.5); + } + min_output = max(0, static_cast(std::ceil(lo))); + max_output = min(output_size - 1, static_cast(std::floor(hi))); +} +#endif + // Backward (adjoint) operation 1 <- 2 (accumulates) template C10_LAUNCH_BOUNDS_1(1024) @@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame( const bool align_corners, scalar_t* __restrict__ idata, const scalar_t* __restrict__ odata) { - const size_t o_numel = nc * width2 * height2; + // In C++, integer multiplication, like in standard arithmetic, is generally commutative. 
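The ROCm paths added in Normalization.cuh above unroll the per-thread strided loop by UNRL: each lane's load index is clamped with min() so it stays in bounds, and out-of-range lanes are masked out of the accumulation afterwards. A minimal Python sketch of that clamp-and-mask pattern follows (illustrative only, not part of the patch; block_dim and the plain list stand in for blockDim.x and the tensor accessor, and the final sum stands in for the warp/block reduction):

def block_sum(vals, block_dim=8, unrl=4):
    n = len(vals)
    partial = [0] * block_dim
    for tid in range(block_dim):                      # threadIdx.x
        x = tid
        while x < n:                                  # for (x = tid; x < n; x += blockDim.x*UNRL)
            # clamped loads: min() keeps tail lanes in range, like min(tensor.size(2)-1, ...)
            tmp = [vals[min(n - 1, x + u * block_dim)] for u in range(unrl)]
            for u in range(unrl):
                if x + u * block_dim < n:             # mask lanes that ran past the end
                    partial[tid] += tmp[u]
            x += block_dim * unrl
    return sum(partial)                               # stands in for the warp/block reduction

vals = list(range(1000))
assert block_sum(vals) == sum(vals)

The Welford variant in batch_norm_collect_statistics_kernel masks the same way, only bumping n and updating avg/var_n for the in-range lanes.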
const size_t i_numel = nc * width1 * height1; +#ifdef USE_ROCM + for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel; + index += blockDim.x * gridDim.x) { + // Decode input pixel coordinates + size_t index_temp = index; + const int w1 = index_temp % width1; + index_temp /= width1; + const int h1 = index_temp % height1; + const size_t nc_idx = index_temp / height1; + + accscalar_t grad_sum = 0; + + // Find range of output pixels that could interpolate from this input pixel + int h2_min, h2_max, w2_min, w2_max; + compute_output_range(h1, rheight, height2, align_corners, h2_min, h2_max); + compute_output_range(w1, rwidth, width2, align_corners, w2_min, w2_max); + + // Iterate over potential output pixels + for (int h2 = h2_min; h2 <= h2_max; h2++) { + for (int w2 = w2_min; w2 <= w2_max; w2++) { + // Compute source coordinates for this output pixel + const accscalar_t h1r = area_pixel_compute_source_index( + rheight, h2, align_corners, /*cubic=*/false); + const int h1_base = (int)h1r; + const int h1p = (h1_base < height1 - 1) ? 1 : 0; + const accscalar_t h1lambda = h1r - h1_base; + const accscalar_t h0lambda = static_cast(1) - h1lambda; + + const accscalar_t w1r = area_pixel_compute_source_index( + rwidth, w2, align_corners, /*cubic=*/false); + const int w1_base = (int)w1r; + const int w1p = (w1_base < width1 - 1) ? 1 : 0; + const accscalar_t w1lambda = w1r - w1_base; + const accscalar_t w0lambda = static_cast(1) - w1lambda; + + // Check if our input pixel participates in this interpolation and accumulate all weights + // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse + // to the same pixel, so we need to accumulate weights from all matching positions + accscalar_t weight = 0; + + // Check all four interpolation positions and accumulate weights + if (h1 == h1_base && w1 == w1_base) { + weight += h0lambda * w0lambda; // top-left + } + if (h1 == h1_base && w1 == w1_base + w1p) { + weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0) + } + if (h1 == h1_base + h1p && w1 == w1_base) { + weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0) + } + if (h1 == h1_base + h1p && w1 == w1_base + w1p) { + weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions) + } + + if (weight > 0) { + const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2; + grad_sum += weight * static_cast(odata[output_idx]); + } + } + } + + // Write accumulated gradient (no atomics needed) + idata[index] = static_cast(grad_sum); + } +#else + const size_t o_numel = nc * width2 * height2; for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel; index += blockDim.x * gridDim.x) { size_t index_temp = index; @@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame( static_cast(h1lambda * w1lambda * d2val), true); } +#endif } template @@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template( // threads are not covering the whole input tensor. 
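The ROCm backward kernel above iterates over grad_input and gathers contributions, instead of scattering atomic adds from grad_output; compute_output_range inverts the source-index mapping to bound which output pixels can touch a given input pixel. A small 1-D Python sketch of the two formulations follows (illustrative only, not part of the patch; align_corners=False throughout, scale taken as in_size/out_size as area_pixel_compute_scale gives when no explicit scale is passed, and backward_scatter/backward_gather are made-up names):

import math, random

def src_index(scale, dst):
    # align_corners=False source index, clamped at 0 like area_pixel_compute_source_index
    return max(scale * (dst + 0.5) - 0.5, 0.0)

def backward_scatter(go, in_size, scale):             # existing formulation: atomicAdd per tap
    gi = [0.0] * in_size
    for j, g in enumerate(go):
        s = src_index(scale, j)
        i0 = int(s)
        p = 1 if i0 < in_size - 1 else 0
        lam = s - i0
        gi[i0] += (1.0 - lam) * g
        gi[i0 + p] += lam * g
    return gi

def backward_gather(go, in_size, scale):              # ROCm formulation: one write per element
    out_size, gi = len(go), [0.0] * in_size
    for i in range(in_size):
        lo = (i - 0.5) / scale - 0.5                  # same bounds as compute_output_range
        hi = (i + 1.5) / scale - 0.5
        for j in range(max(0, math.ceil(lo)), min(out_size - 1, math.floor(hi)) + 1):
            s = src_index(scale, j)
            i0 = int(s)
            p = 1 if i0 < in_size - 1 else 0
            lam = s - i0
            w = 0.0
            if i == i0:
                w += 1.0 - lam
            if i == i0 + p:
                w += lam                              # collapsed taps accumulate, as in the kernel
            gi[i] += w * go[j]
    return gi

in_size, out_size = 7, 19
scale = in_size / out_size
go = [random.random() for _ in range(out_size)]
a, b = backward_scatter(go, in_size, scale), backward_gather(go, in_size, scale)
assert all(abs(x - y) < 1e-9 for x, y in zip(a, b))

Because each grad_input element is produced by exactly one iteration, the ROCm kernel can store the accumulated gradient directly rather than going through atomics.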
grad_input.zero_(); - const size_t num_kernels = nbatch * channels * output_height * output_width; const int num_threads = std::min( at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template( return; } +#ifdef USE_ROCM + constexpr bool use_input = true; +#else + constexpr bool use_input = false; +#endif + AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] { @@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners, scales_w); + const size_t num_kernels = nbatch * channels * output_height * output_width; + upsample_bilinear2d_backward_nhwc_out_frame <<(num_threads)), num_threads, 0, stream>>>( input_height, @@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template( const accscalar_t rwidth = area_pixel_compute_scale( input_width, output_width, align_corners, scales_w); + const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width); + upsample_bilinear2d_backward_out_frame <<(num_threads)), num_threads, diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 940680eb3682f..81387bcceaf01 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -141,7 +141,11 @@ WelfordDataLN cuWelfordOnlineSum( if constexpr (!rms_norm){ U delta = val - curr_sum.mean; U new_count = curr_sum.count + 1.f; +#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL) + U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count); +#else U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster +#endif return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; } else{ return {0.f, curr_sum.sigma2 + val * val, 0}; @@ -159,7 +163,11 @@ WelfordDataLN cuWelfordCombine( U count = dataA.count + dataB.count; U mean, sigma2; if (count > decltype(dataB.count){0}) { +#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL) + auto coef = __builtin_amdgcn_rcpf(count); +#else auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division +#endif auto nA = dataA.count * coef; auto nB = dataB.count * coef; mean = nA*dataA.mean + nB*dataB.mean; diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index b8b43e0086c1a..c2193f2378dd5 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -95,6 +95,72 @@ #endif #endif +#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)) +namespace pytorch_flash +{ +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor> +mha_fwd( + const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size + std::optional& + out_, // batch_size x seqlen_q x 
num_heads x head_size + std::optional& + alibi_slopes_, // num_heads or batch_size x num_heads + const float p_dropout, + const float softmax_scale, + bool is_causal, + std::optional window_size_left, + std::optional window_size_right, + const float softcap, + const bool return_softmax, + std::optional gen_) { +#if defined(USE_ROCM_CK_SDPA) + if (at::globalContext().getROCmFAPreferredBackend() == + at::ROCmFABackend::Ck) { + const int non_null_window_left = window_size_left.value_or(-1); + const int non_null_window_right = window_size_right.value_or(-1); + std::optional dummy_attn_bias = std::nullopt; + return mha_fwd_ck( + q, + k, + v, + out_, + p_dropout, + softmax_scale, + is_causal, + non_null_window_left, + non_null_window_right, + return_softmax, + gen_, + dummy_attn_bias); // Not used in flash attention + } +#endif + return mha_fwd_aot( + q, + k, + v, + out_, + alibi_slopes_, + p_dropout, + softmax_scale, + is_causal, + window_size_left, + window_size_right, + return_softmax, + gen_); +} +} +#endif + namespace at { namespace cuda::philox { diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 660aee3647cea..8eec0de7773f3 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -176,6 +176,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { } return false; } + if constexpr(caller_is_meff) { + bool is_half = (params.query.dtype() == at::kHalf) || + (params.query.dtype() == at::kBFloat16); + const int64_t alignment = is_half ? 8 : 4; + if (!(query_size_last % alignment == 0 && query_size_last > 0 && + value_size_last % alignment == 0 && value_size_last > 0)) { + if (debug) { + TORCH_WARN( + "Mem efficient attention requires last dimension of inputs to be divisible by ", + alignment, + ". ", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead."); + } + return false; + } + } return true; } diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index b5b1ed4292896..2467cb809fdbf 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -462,10 +462,11 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; + using sdp::aotriton_adapter::mk_atomictensor; using sdp::aotriton_adapter::cast_dtype; at::Tensor atomic_counter; if (is_causal) { - atomic_counter = at::zeros({1}, q.options()); + atomic_counter = at::zeros({1}, q.options().dtype(at::kInt)); } aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t); @@ -474,7 +475,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto nullscalar = mk_philoxtensor(nullptr); auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; - auto persistent_counter = is_causal ? 
mk_philoxtensor(atomic_counter.data_ptr()) : nullscalar; + auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); if (uses_swa || AOTRITON_ALWAYS_V3_API) { #if AOTRITON_V3_API using aotriton::v3::flash::CausalType; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h index f6f2240d4f091..71a1959065970 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h @@ -270,7 +270,7 @@ std::tuple mha_varle #endif TORCH_API -inline std::tuple< +std::tuple< at::Tensor, at::Tensor, at::Tensor, @@ -294,42 +294,7 @@ mha_fwd( std::optional window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { -#if defined(USE_ROCM_CK_SDPA) - if (at::globalContext().getROCmFAPreferredBackend() == - at::ROCmFABackend::Ck) { - const int non_null_window_left = window_size_left.value_or(-1); - const int non_null_window_right = window_size_right.value_or(-1); - std::optional dummy_attn_bias = std::nullopt; - return mha_fwd_ck( - q, - k, - v, - out_, - p_dropout, - softmax_scale, - is_causal, - non_null_window_left, - non_null_window_right, - return_softmax, - gen_, - dummy_attn_bias); // Not used in flash attention - } -#endif - return mha_fwd_aot( - q, - k, - v, - out_, - alibi_slopes_, - p_dropout, - softmax_scale, - is_causal, - window_size_left, - window_size_right, - return_softmax, - gen_); -} + std::optional gen_); inline std::tuple< at::Tensor, diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index ef5c2fd4e97de..daceebd8bc889 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1037,6 +1037,22 @@ if(USE_ROCM) list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling) endif(CMAKE_BUILD_TYPE MATCHES Debug) + # Get EnVar 'PYTORCH_LAYERNORM_FAST_RECIPROCAL' (or default to on). + if(DEFINED ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL}) + set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE $ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL}) + else() + set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE ON) + endif() + + set(PYTORCH_LAYERNORM_FAST_RECIPROCAL + ${PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE} + CACHE BOOL "Enable fast reciprocals within layer normalization." 
FORCE + ) + + if(PYTORCH_LAYERNORM_FAST_RECIPROCAL) + add_definitions(-DPYTORCH_LAYERNORM_FAST_RECIPROCAL) + endif() + # needed for compat with newer versions of hip-clang that introduced C++20 mangling rules list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17) diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 5d91587746540..f09f77bedb80f 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -45,13 +45,89 @@ if(NOT __AOTRITON_INCLUDED) ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") + # Set the default __AOTRITON_LIB path + if(NOT WIN32) + set(__AOTRITON_LIB "lib/libaotriton_v2.so") + else() + set(__AOTRITON_LIB "lib/aotriton_v2.lib") + endif() + + function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + # Windows-specific dependencies - build these first + if(NOT noimage) + message(FATAL_ERROR "noimage must be ON for Windows builds") + endif() + # Build dlfcn-win32 + set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32") + set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install") + + ExternalProject_Add(${dlfcn-win32_external} + GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git + GIT_TAG v1.4.2 + PREFIX ${__DLFCN_WIN32_PREFIX} + INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_C_COMPILER=cl + -DCMAKE_CXX_COMPILER=cl + -DBUILD_SHARED_LIBS=ON + -DBUILD_TESTS=OFF + BUILD_BYPRODUCTS + "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib" + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + ) + ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE) + + # Build xz/liblzma + set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz") + set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install") + + ExternalProject_Add(${xz_external} + GIT_REPOSITORY https://github.com/tukaani-project/xz.git + GIT_TAG v5.8.1 + PREFIX ${__XZ_PREFIX} + INSTALL_DIR ${__XZ_INSTALL_DIR} + CMAKE_ARGS + -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR} + -DCMAKE_BUILD_TYPE=Release + -DBUILD_SHARED_LIBS=ON + -DENABLE_NLS=OFF + -DXZ_TOOL_LZMAINFO=OFF + -DXZ_TOOL_XZ=OFF + -DXZ_TOOL_XZDEC=OFF + -DXZ_TOOL_LZMADEC=OFF + BUILD_BYPRODUCTS + "${__XZ_INSTALL_DIR}/lib/lzma.lib" + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + ) + ExternalProject_Add_Step(${xz_external} copy_to_aotriton + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${__XZ_INSTALL_DIR}/bin/liblzma.dll" + "${__AOTRITON_INSTALL_DIR}/lib/" + DEPENDEES install + ) + set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE) + endfunction() + function(aotriton_build_from_source noimage project) if(noimage) SET(RECURSIVE "OFF") else() SET(RECURSIVE "ON") endif() + if(WIN32) + message(STATUS "Building AOTriton Windows dependencies") + aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR) + endif() message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}") + ExternalProject_Add(${project} GIT_REPOSITORY https://github.com/ROCm/aotriton.git GIT_SUBMODULES_RECURSE ${RECURSIVE} @@ -65,12 +141,18 @@ if(NOT 
__AOTRITON_INCLUDED) -DAOTRITON_GPU_BUILD_TIMEOUT=0 -DAOTRITON_NO_PYTHON=ON -DAOTRITON_NOIMAGE_MODE=${noimage} - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + -DHIP_PLATFORM=amd + $<$:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}> + $<$:-Dliblzma_DIR=${liblzma_DIR}> + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" USES_TERMINAL_DOWNLOAD TRUE USES_TERMINAL_CONFIGURE TRUE USES_TERMINAL_BUILD TRUE USES_TERMINAL_INSTALL TRUE ) + if(WIN32) + add_dependencies(${project} dlfcn-win32_external xz_external) + endif() endfunction() set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) @@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED) INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime" "${__AOTRITON_INSTALL_DIR}" - BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so" + BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}" ) message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\ Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.") @@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED) string(CONCAT __AOTRITON_URL "${__AOTRITON_BASE_URL}" "${__AOTRITON_VER}/${__AOTRITON_FILE}") + + # Set up directories + set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image}) + set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}) + set(__DOWNLOAD_NO_EXTRACT "") + set(__BUILD_COMMANDS "") + + # On Windows, we need custom tar extraction with UTF-8 support + if(WIN32) + set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE") + set(__BUILD_COMMANDS + COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}" + COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}" + ) + set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton) + endif() + ExternalProject_Add(${project} URL "${__AOTRITON_URL}" URL_HASH SHA256=${__AOTRITON_SHA256} - SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image} + DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR} + ${__DOWNLOAD_NO_EXTRACT} + SOURCE_DIR ${__AOTRITON_EXTRACT_DIR} CONFIGURE_COMMAND "" BUILD_COMMAND "" + ${__BUILD_COMMANDS} INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory - "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}" + "${__AOTRITON_INSTALL_SOURCE_DIR}" "${__AOTRITON_INSTALL_DIR}" BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__" @@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED) endforeach() endforeach() endif() - target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so) + target_link_libraries(__caffe2_aotriton INTERFACE "${__AOTRITON_INSTALL_DIR}/${__AOTRITON_LIB}") target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include) set(AOTRITON_FOUND TRUE) endif() # __AOTRITON_INCLUDED diff --git a/cmake/Modules/FindOpenBLAS.cmake b/cmake/Modules/FindOpenBLAS.cmake index 9ba86ba1ee0f4..21ae9e2521eb2 100644 --- a/cmake/Modules/FindOpenBLAS.cmake +++ b/cmake/Modules/FindOpenBLAS.cmake @@ -29,10 +29,15 @@ SET(Open_BLAS_LIB_SEARCH_PATHS $ENV{OpenBLAS}/lib $ENV{OpenBLAS_HOME} $ENV{OpenBLAS_HOME}/lib - ) +) + +SET(Open_BLAS_LIB_NAME openblas) +IF(DEFINED ENV{OpenBLAS_LIB_NAME}) + SET(Open_BLAS_LIB_NAME $ENV{OpenBLAS_LIB_NAME}) +ENDIF() FIND_PATH(OpenBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${Open_BLAS_INCLUDE_SEARCH_PATHS}) -FIND_LIBRARY(OpenBLAS_LIB 
NAMES openblas PATHS ${Open_BLAS_LIB_SEARCH_PATHS}) +FIND_LIBRARY(OpenBLAS_LIB NAMES ${Open_BLAS_LIB_NAME} PATHS ${Open_BLAS_LIB_SEARCH_PATHS}) SET(OpenBLAS_FOUND ON) diff --git a/docs/source/library.md b/docs/source/library.md index b31ca95d5b6a3..9d706e2e1080e 100644 --- a/docs/source/library.md +++ b/docs/source/library.md @@ -56,7 +56,6 @@ via PyTorch's C++ operator registration APIs). .. autofunction:: infer_schema .. autoclass:: torch._library.custom_ops.CustomOpDef :members: set_kernel_enabled -.. autofunction:: get_kernel ``` ## Low-level APIs diff --git a/related_commits b/related_commits new file mode 100644 index 0000000000000..b96cf18c181ab --- /dev/null +++ b/related_commits @@ -0,0 +1,10 @@ +ubuntu|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex +centos|pytorch|apex|release/1.9.0|07c3ee5347294b7a07a65c2c3596f1b14c7d3daa|https://github.com/ROCm/apex +ubuntu|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision +centos|pytorch|torchvision|release/0.24|b919bd0c56abbb3c5ca056a3a458af9fd1cabf52|https://github.com/pytorch/vision +ubuntu|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data +centos|pytorch|torchdata|release/0.11|377e64c1be69a9be6649d14c9e3664070323e464|https://github.com/pytorch/data +ubuntu|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio +centos|pytorch|torchaudio|release/2.9|e3c6ee2b6588b7cd27a84182de74bf12fe043831|https://github.com/pytorch/audio +ubuntu|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao +centos|pytorch|ao|main|a52a64aeb84fa6ff683ec2c7c42b97e27651a619|https://github.com/pytorch/ao diff --git a/requirements-build.txt b/requirements-build.txt index be19d987f73db..3d21094159d79 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -1,10 +1,10 @@ # Build System requirements -setuptools>=70.1.0,<80.0 # setuptools develop deprecated on 80.0 -cmake>=3.27 -ninja -numpy -packaging -pyyaml -requests -six # dependency chain: NNPACK -> PeachPy -> six -typing-extensions>=4.10.0 +setuptools==79.0.1 +cmake==4.0.0 +ninja==1.11.1.3 +numpy==2.1.2 +packaging==25.0 +pyyaml==6.0.3 +requests==2.32.5 +six==1.17.0 # dependency chain: NNPACK -> PeachPy -> six +typing_extensions==4.15.0 diff --git a/requirements.txt b/requirements.txt index fc4b53dfd49ea..824ca112602a0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,15 +5,18 @@ # Install / Development extra requirements build[uv] # for building sdist and wheel -expecttest>=0.3.0 -filelock -fsspec>=0.8.5 -hypothesis -jinja2 -lintrunner ; platform_machine != "s390x" and platform_machine != "riscv64" -networkx>=2.5.1 -optree>=0.13.0 -psutil -sympy>=1.13.3 -typing-extensions>=4.13.2 +expecttest==0.3.0 +filelock==3.20.0 +fsspec==2025.9.0 +hypothesis==5.35.1 +Jinja2==3.1.6 +lintrunner==0.12.7 ; platform_machine != "s390x" and platform_machine != "riscv64" +networkx==2.8.8 +ninja==1.11.1.3 +numpy==2.0.2 ; python_version == "3.9" +numpy==2.1.2 ; python_version > "3.9" +optree==0.13.0 +psutil==7.1.0 +sympy==1.13.3 +typing_extensions==4.15.0 wheel diff --git a/setup.py b/setup.py index 11ca48482a761..ae0097465da66 100644 --- a/setup.py +++ b/setup.py @@ -162,6 +162,10 @@ # USE_ROCM_CK_SDPA=1 # Enable building CK SDPA backend in ROCm platform # +# PYTORCH_LAYERNORM_FAST_RECIPROCAL +# If set, enables the use of builtin functions for fast reciprocals (1/x) w.r.t. 
+# layer normalization. Default: enabled. +# # Environment variables we respect (these environment variables are # conventional and are often understood/set by other software.) # diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py index e930ff787a9a4..ce456596fd55e 100644 --- a/test/dynamo/test_graph_region_tracker.py +++ b/test/dynamo/test_graph_region_tracker.py @@ -1,6 +1,5 @@ # Owner(s): ["module: dynamo"] import contextlib -import os import torch import torch.fx @@ -196,21 +195,6 @@ def fn(x, y, z): ) def test_mismatched_global_state(self): - @contextlib.contextmanager - def _hip_allow_tf32(): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+ - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - def inner_fn(x, y): x1 = x * 1 y1 = y + 1 @@ -251,31 +235,29 @@ def set_default_dtype_bfloat16(): def reset_default_dtype(): torch.set_default_dtype(old_dtype) - tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - with tf32_ctx(): - for ctx in [ - lambda: torch.set_grad_enabled(False), - torch.autograd.grad_mode.inference_mode, - lambda: torch.autograd.graph.disable_saved_tensors_hooks( - "This is not supported" - ), - # lambda: torch.set_num_threads(2), : Unsupported - (set_default_dtype_bfloat16, reset_default_dtype), - ( - lambda: torch.use_deterministic_algorithms(True), - lambda: torch.use_deterministic_algorithms(False), - ), - # (lambda: torch.use_deterministic_algorithms(True, warn_only=True), - # lambda: torch.use_deterministic_algorithms(False)), : Unsupported - create_toggle_fns("allow_bf16_reduced_precision_reduction"), - create_toggle_fns("allow_fp16_reduced_precision_reduction"), - create_toggle_fns("allow_tf32"), - ]: - self.assertExpectedInline( - self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx), - """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \ + for ctx in [ + lambda: torch.set_grad_enabled(False), + torch.autograd.grad_mode.inference_mode, + lambda: torch.autograd.graph.disable_saved_tensors_hooks( + "This is not supported" + ), + # lambda: torch.set_num_threads(2), : Unsupported + (set_default_dtype_bfloat16, reset_default_dtype), + ( + lambda: torch.use_deterministic_algorithms(True), + lambda: torch.use_deterministic_algorithms(False), + ), + # (lambda: torch.use_deterministic_algorithms(True, warn_only=True), + # lambda: torch.use_deterministic_algorithms(False)), : Unsupported + create_toggle_fns("allow_bf16_reduced_precision_reduction"), + create_toggle_fns("allow_fp16_reduced_precision_reduction"), + create_toggle_fns("allow_tf32"), + ]: + self.assertExpectedInline( + self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx), + """[[['x1_2', 'y1_2', 'sum_3', 'o0'], ['x1_3', 'y1_3', 'sum_4', 'o2']], \ [['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""", - ) + ) def test_mutation_tracking_simple(self): def fn(x, y, z): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 1a9d8e8155e43..0a3891e2dc146 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -8421,43 +8421,24 @@ def write_state(state): def fn(x): return x + 1 - import contextlib - - @contextlib.contextmanager - def _hip_allow_tf32(): - # for HIP/AMDGPU, tf32 is behind a flag 
because the TF32 support is new - # and only for MI300+ - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - - tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - with tf32_ctx(): - initial_state = read_state() - y = torch.randn(10) - try: - for round in range(3): - for i in range(len(initial_state)): - new_state = [False] * len(initial_state) - new_state[i] = True - write_state(new_state) - assert read_state() == new_state - last_state.clear() - fn(y) - assert last_state == new_state - if round == 0: - assert cnt == i + 1 - else: - assert cnt == len(initial_state) - finally: - write_state(initial_state) + initial_state = read_state() + y = torch.randn(10) + try: + for round in range(3): + for i in range(len(initial_state)): + new_state = [False] * len(initial_state) + new_state[i] = True + write_state(new_state) + assert read_state() == new_state + last_state.clear() + fn(y) + assert last_state == new_state + if round == 0: + assert cnt == i + 1 + else: + assert cnt == len(initial_state) + finally: + write_state(initial_state) def test_grad_state_mutated(self): prior = torch.is_grad_enabled() diff --git a/test/inductor/test_async_compile.py b/test/inductor/test_async_compile.py index 5a61ea851eae0..cc94c4c95e01a 100644 --- a/test/inductor/test_async_compile.py +++ b/test/inductor/test_async_compile.py @@ -74,7 +74,14 @@ def f(a, b): return (a @ b).to(torch.float32).sum(dim=1) # Fake name to make sure the lookup table is name agnostic - func_def = """ + # When codegen/triton.py is changed, func_def must be updated + loop_header = ( + "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):" + if torch.version.hip + else "for r0_offset in range(0, r0_numel, R0_BLOCK):" + ) + + func_def = f""" def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr): xnumel = 1024 r0_numel = 11776 @@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons rbase = r0_base x0 = xindex _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32) - for r0_offset in range(0, r0_numel, R0_BLOCK): + {loop_header} r0_index = r0_offset + r0_base r0_mask = r0_index < r0_numel roffset = r0_offset diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index 90399546d26ea..6523cddcec6db 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8) - @requires_cuda_and_triton - def test_persistent_reduction_no_x_dim(self): - def fn(x, y): - return x.sum(1), y.sum(1) - - inps = ( - torch.rand(16, 256, device="cuda"), - torch.rand(32, 256, device="cuda"), - ) - torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) - torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) - out_eager = fn(*inps) - out_compiled = torch.compile(fn)(*inps) - - self.assertEqual(out_eager, out_compiled) - self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4) - @instantiate_parametrized_tests class ComboKernelDynamicShapesTests(TestCase): diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index 120d8d36b439d..849aefff8a965 100644 --- a/test/inductor/test_flex_decoding.py +++ 
b/test/inductor/test_flex_decoding.py @@ -43,9 +43,6 @@ Tolerances = namedtuple("Tolerances", ["atol", "rtol"]) -# In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul. -# In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the -# logic of allowTF32CuBLAS(), set float32_matmul_precision to highest. if torch.version.hip: torch.set_float32_matmul_precision("highest") else: diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 9ef3a18e24234..c67bde87a369b 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -109,9 +109,6 @@ def setUpClass(cls): if HAS_GPU: cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision() cls.prior_default_device = torch.get_default_device() - # In MI300, HIPBLASLT_ALLOW_TF32=1 is used to enable tf32 for matmul. - # In the current test, HIPBLASLT_ALLOW_TF32 is not set, according to the - # logic of allowTF32CuBLAS(), set float32_matmul_precision to highest. if torch.version.hip: torch.set_float32_matmul_precision("highest") else: diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 41db6b18daba7..6bde7a8c540a4 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -816,6 +816,7 @@ def test_2d_reduction_odd_shapes( # Check the code for multiple Rn_BLOCK's self._assert_reduction_ndims(code, 2) + @parametrize( "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback", [ diff --git a/test/nn/test_multihead_attention.py b/test/nn/test_multihead_attention.py index c0419664d0098..40dca90b16488 100644 --- a/test/nn/test_multihead_attention.py +++ b/test/nn/test_multihead_attention.py @@ -17,7 +17,6 @@ instantiate_parametrized_tests, parametrize as parametrize_test, run_tests, - skipIfRocm, TEST_NUMPY, TEST_WITH_CROSSREF, ) @@ -746,7 +745,6 @@ def test_multihead_attn_nested_tensor_outside_fast_path(self): class TestMultiheadAttentionNNDeviceType(NNTestCase): - @skipIfRocm(msg="To investigate: yields NaN") def test_multihead_self_attn_two_masks_fast_path(self, device): """ Multihead self-attention should give the same result on the fast path (BetterTransformer) as on the slow path diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 1c31d5445f915..569d1bac85958 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -1480,8 +1480,8 @@ def to_np(value): self.assertRaisesRegex(RuntimeError, regex, base.pow_, exponent) elif torch.can_cast(torch.result_type(base, exponent), base.dtype): actual2 = actual.pow_(exponent) - self.assertEqual(actual, expected) - self.assertEqual(actual2, expected) + self.assertEqual(actual, expected.to(actual)) + self.assertEqual(actual2, expected.to(actual)) else: self.assertRaisesRegex( RuntimeError, diff --git a/test/test_cuda.py b/test/test_cuda.py index 293bb2b7e701b..d293601fad138 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -504,6 +504,9 @@ def test_out_of_memory_retry(self): IS_JETSON, "oom reporting has issues on jetson igx due to partial nvml support" ) def test_set_per_process_memory_fraction(self): + if torch.version.hip and ('gfx1101' in torch.cuda.get_device_properties(0).gcnArchName): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() orig = torch.cuda.get_per_process_memory_fraction(0) torch.cuda.reset_peak_memory_stats(0) try: @@ -759,53 +762,7 @@ def check_workspace_size(inp): 
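On the test_binary_ufuncs.py change above: comparing against expected.to(actual) first matches the reference tensor to the in-place result's dtype and device. A minimal sketch of the idea (illustrative only, with hypothetical values, not the test's actual inputs):

import torch

base = torch.tensor([1.0, 2.0, 3.0])                  # float32 operand of the in-place op
expected = base.double() ** 2                         # reference kept in higher precision
actual = base.clone().pow_(2)                         # stays float32
# Tensor.to(other) casts to other's dtype and device before the comparison.
torch.testing.assert_close(actual, expected.to(actual))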
torch._C._cuda_clearCublasWorkspaces() - @contextlib.contextmanager - def _hip_allow_tf32(self): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+ - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - - @unittest.skipIf(not TEST_WITH_ROCM, "not relevant for CUDA testing") - def test_hipblaslt_allow_tf32(self): - tf32_ctx = self._hip_allow_tf32 - with tf32_ctx(): - os.environ["HIPBLASLT_ALLOW_TF32"] = "0" - # Save original value of allow_tf32 - orig = torch.backends.cuda.matmul.allow_tf32 - # If allow_tf32 variable is declared as static in aten/src/ATen/Context.cpp - # then matmul.allow_tf32 will return False after this point even if - # HIP_BLASLT_ALLOW_TF32 is set to 1 and matmul.allow_tf32 is changed. - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - # Toggle torch.backends.cuda.matmul.allow_tf32 couple of times. - torch.backends.cuda.matmul.allow_tf32 = not orig - test1 = torch.backends.cuda.matmul.allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = orig - test2 = torch.backends.cuda.matmul.allow_tf32 - self.assertNotEqual(test1, test2) - # Restore original value of allow_tf32 - torch.backends.cuda.matmul.allow_tf32 = orig - def test_cublas_allow_tf32_get_set(self): - """ - We only turn on TF32 for MI300 with a special env var. This is because TF32 - is only available in MI300+ and is in experimental mode (hipblaslt support - is current WIP) - """ - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - - with tf32_ctx(): - self._test_cublas_allow_tf32_get_set_inner() - - def _test_cublas_allow_tf32_get_set_inner(self): skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] ) @@ -820,12 +777,6 @@ def _test_cublas_allow_tf32_get_set_inner(self): torch.backends.cuda.matmul.allow_tf32 = orig def test_float32_matmul_precision_get_set(self): - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - - with tf32_ctx(): - self._test_float32_matmul_precision_get_set_inner() - - def _test_float32_matmul_precision_get_set_inner(self): orig = torch.get_float32_matmul_precision() skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index 491648494f6f0..5a494f5487423 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -11,7 +11,6 @@ import tempfile import typing import unittest -from functools import partial from pathlib import Path from typing import * # noqa: F403 @@ -4157,148 +4156,6 @@ def test_any_output_is_alias_to_input_or_output(self): ) ) - def test_library_get_kernel(self): - """Test registering a custom kernel, using it, then deregistering and verifying error.""" - - # Register a dummy kernel for arange to the CPU key that returns a tensor of ones - def dummy_arange_cpu( - dispatch_keys, - start, - end, - dtype=None, - layout=torch.strided, - device=None, - pin_memory=False, - ): - size = max(0, int(end - start)) - return torch.ones(size, dtype=dtype, device=device) - - with torch.library._scoped_library("aten", "IMPL") as lib: - lib.impl("arange.start", dummy_arange_cpu, "CPU", with_keyset=True) - - kernel = torch.library.get_kernel("aten::arange.start", 
"CPU") - dispatch_keys = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU) - result = kernel.call_boxed(dispatch_keys, 0, 5) - - self.assertEqual(result, torch.ones(5)) - - # The kernel should now be invalidated after exiting the scoped_library context - with self.assertRaisesRegex(RuntimeError, "has been invalidated"): - kernel.call_boxed(dispatch_keys, 0, 5) - - def test_library_get_kernel_with_conditional_dispatch(self): - """Test registering a custom kernel with conditional dispatch logic.""" - - def conditional_arange_cpu1( - original_kernel, - dispatch_keys, - start, - end, - dtype=None, - layout=torch.strided, - device=None, - pin_memory=False, - ): - # If end is even, use the original kernel, otherwise return ones tensor - if end % 2 == 0: - op_handle = torch.ops.aten.arange.start._handle - return original_kernel.call_boxed( - dispatch_keys, - start, - end, - dtype=dtype, - layout=layout, - device=device, - pin_memory=pin_memory, - ) - else: - size = max(0, int(end - start)) - return torch.ones(size, dtype=dtype, device=device) - - def conditional_arange_cpu2( - original_kernel, - dispatch_keys, - start, - end, - dtype=None, - layout=torch.strided, - device=None, - pin_memory=False, - ): - # If start is even, use the original kernel, otherwise return twos tensor - if start % 2 == 0: - op_handle = torch.ops.aten.arange.start._handle - return original_kernel.call_boxed( - dispatch_keys, - start, - end, - dtype=dtype, - layout=layout, - device=device, - pin_memory=pin_memory, - ) - else: - size = max(0, int(end - start)) - return torch.empty(size, dtype=dtype, device=device).fill_(2) - - original_kernel = torch.library.get_kernel("aten::arange.start", "CPU") - expected_result1, expected_result2 = torch.ones(5), torch.arange(0, 6) - expected_result3, expected_result4, expected_result5 = ( - torch.ones(5), - torch.arange(0, 6), - torch.ones(5).fill_(2), - ) - - with torch.library._scoped_library("aten", "IMPL") as lib2: - with torch.library._scoped_library("aten", "IMPL") as lib1: - lib1.impl( - "arange.start", - partial(conditional_arange_cpu1, original_kernel), - "CPU", - with_keyset=True, - ) - - self.assertEqual(torch.arange(0, 5), expected_result1) - self.assertEqual(torch.arange(0, 6), expected_result2) - new_original_kernel = torch.library.get_kernel( - "aten::arange.start", "CPU" - ) - lib2.impl( - "arange.start", - partial(conditional_arange_cpu2, new_original_kernel), - "CPU", - allow_override=True, - with_keyset=True, - ) - - self.assertEqual(torch.arange(0, 5), expected_result3) - self.assertEqual(torch.arange(0, 6), expected_result4) - self.assertEqual(torch.arange(1, 6), expected_result5) - - # The kernel should now be invalidated after destroying lib1 - with self.assertRaisesRegex(RuntimeError, "has been invalidated"): - torch.arange(0, 5) - - # Should still work after destroying lib1 - self.assertEqual(torch.arange(1, 6), expected_result5) - - def test_library_get_kernel_invalid(self): - """Test that get_kernel raises an error when no kernel is available.""" - with torch.library._scoped_library("test_invalid_kernel", "DEF") as lib: - lib.define("cpu_only_op(Tensor x) -> Tensor") - lib.impl("cpu_only_op", lambda x: x * 2, "CPU") - - cpu_kernel = torch.library.get_kernel( - "test_invalid_kernel::cpu_only_op", "CPU" - ) - self.assertIsNotNone(cpu_kernel) - - # CUDA should fail at the isValid() check since no CUDA kernel exists - with self.assertRaisesRegex( - RuntimeError, "no kernel for CUDA for test_invalid_kernel::cpu_only_op" - ): - 
torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA") - class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index c44d5e5d41454..17e699e04e589 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -15,7 +15,6 @@ ) from torch.testing._internal.common_utils import ( run_tests, - skipIfRocm, TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -463,7 +462,6 @@ def get_flops( self.assertExpectedInline(str(flops_fw_bw_math), """805306368""") self.assertExpectedInline(str(flops_fw_bw_efficient), """939524096""") - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -683,7 +681,6 @@ def split_tensor(x): ), ) - @skipIfRocm # Nested tensor @unittest.skipIf(not HAS_CUDA, "CUDA not available") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, diff --git a/test/test_linalg.py b/test/test_linalg.py index 0f6c8f207421b..31d4e0d1d92d5 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -109,22 +109,6 @@ def get_tunableop_untuned_filename(): return untuned_filename class TestLinalg(TestCase): - @contextlib.contextmanager - def _hip_allow_tf32(self): - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+. Environment variable will be removed in the future. - import os - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" - - try: - yield - finally: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] - def setUp(self): super().setUp() torch.backends.cuda.matmul.allow_tf32 = False @@ -5542,13 +5526,8 @@ def test_scaled_gemm_tunableop(self, device, dtype): @runOnRocmArch(MI300_ARCH) @dtypes(torch.float) def test_tf32_tunableop(self, device, dtype): - # Test TunableOp with TF32. Supported by hipblasLT on MI300+. - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+. Eventually this flag will go away. - tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - try: - with self._tunableop_ctx(), tf32_ctx(): + with self._tunableop_ctx(): torch.backends.cuda.matmul.allow_tf32 = True torch.cuda.tunable.set_rotating_buffer_size(0) @@ -5611,13 +5590,8 @@ def test_tf32_offline_tunableop(self, device, dtype): # This test is the offline version of test_tf32_tunableop import os - # Test TunableOp with TF32. Supported by hipblasLT on MI300+. - # for HIP/AMDGPU, tf32 is behind a flag because the TF32 support is new - # and only for MI300+. Eventually this flag will go away. 
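With the HIPBLASLT_ALLOW_TF32 environment-variable gate removed from these tests, TF32 is toggled through the regular backend switch alone, saved and restored around the body. A minimal sketch of that pattern (illustrative only, not part of the patch):

import torch

orig = torch.backends.cuda.matmul.allow_tf32
try:
    torch.backends.cuda.matmul.allow_tf32 = True      # no env-var context manager needed
    # ... run the TF32-sensitive section ...
finally:
    torch.backends.cuda.matmul.allow_tf32 = orig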
- tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext - try: - with self._tunableop_ctx(), tf32_ctx(): + with self._tunableop_ctx(): torch.backends.cuda.matmul.allow_tf32 = True ordinal = torch.cuda.current_device() torch.cuda.tunable.set_rotating_buffer_size(0) diff --git a/test/test_nn.py b/test/test_nn.py index c17f7cb668b6f..d5c245c5887d2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -39,7 +39,7 @@ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ skipIfTorchDynamo, gcIfJetson, set_default_dtype from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ - PLATFORM_SUPPORTS_FLASH_ATTENTION, _get_torch_rocm_version + _get_torch_rocm_version from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input @@ -3166,7 +3166,6 @@ def perm_fn(x): [2.42240309, 0.0354595, -0.60659063, -0.05378816]]])) torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0) - @skipIfRocm(msg='Large numerical errors') def test_transformerdecoder(self): def get_a_test_layer(use_cuda, activation, batch_first=False): d_model = 4 @@ -12998,8 +12997,6 @@ def test_skip_init(self, device): @dtypes(torch.float) @dtypesIfCUDA(torch.double, torch.float, torch.half) def test_transformerencoderlayer(self, device, dtype): - if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: - self.skipTest("Skip on ROCM due to Flash Attention tolerances") # this is a deterministic test for TransformerEncoderLayer d_model = 4 nhead = 2 @@ -13221,8 +13218,6 @@ def test_transformerencoderlayer_fast_path(self, device, dtype): @dtypes(torch.float) @dtypesIfCUDA(torch.half, torch.float) def test_transformerencoderlayer_gelu(self, device, dtype): - if TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION and dtype == torch.half: - self.skipTest("Skip on ROCM due to Flash Attention tolerances") # this is a deterministic test for TransformerEncoderLayer with gelu activation d_model = 4 nhead = 2 diff --git a/test/test_transformers.py b/test/test_transformers.py index 5b240e1f046c9..b2a3959a50429 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -51,7 +51,6 @@ PLATFORM_SUPPORTS_CUDNN_ATTENTION, tf32_on_and_off, tf32_enabled, - ROCM_VERSION, ) if TEST_FAIRSEQ: @@ -340,14 +339,11 @@ def test_train_with_pad_and_catch_error(self, device): l1_bool = nn.L1Loss()(test_train_bool[:, 0:2, :], test_eval_bool[:, 0:2, :]).item() self.assertTrue(l1_bool < 1e-4, "Eval/Train difference in pad_mask BOOL") - @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) + @tf32_on_and_off(0.001) @parametrize("attn_mask_dim", [2, 3, None]) @parametrize("key_padding_mask_dim", [2, None]) @parametrize("mask_dtype", [torch.bool, torch.float32]) def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_padding_mask_dim, mask_dtype): - if TEST_WITH_ROCM: - if attn_mask_dim is not None and mask_dtype == torch.bool: - self.skipTest("boolean mask is not fully supported on ROCm yet.") # MHA converts all with torch.no_grad(): B = 2 @@ -430,8 +426,7 @@ def hook(module, inputs, output): # remove hook handle.remove() - @skipIfRocm - @tf32_on_and_off(0.001) + @tf32_on_and_off(0.0021 if TEST_WITH_ROCM else 0.001) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", 
[True, False]) @parametrize("use_autocast", [True, False]) @@ -524,7 +519,7 @@ def test_transformerencoder_fastpath(self, device, use_torchscript, enable_neste slowpath_output = slowpath_output.masked_fill(src_key_padding_mask.unsqueeze(-1), 0) self.assertEqual(fastpath_output_expanded, slowpath_output) - @tf32_on_and_off(0.001, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) + @tf32_on_and_off(0.001) @parametrize("with_no_grad", [True, False]) @parametrize("training", [True, False]) @parametrize("enable_nested_tensor", [False]) @@ -1110,7 +1105,7 @@ def forward( return_all_hiddens=False, )[0] - @tf32_on_and_off(0.003, only_if=(not TEST_WITH_ROCM or ROCM_VERSION < (7, 0))) + @tf32_on_and_off(0.003) @parametrize("input_dim,attn_mask_dim,is_causal", [(3, None, False), (3, 2, False), (3, 2, True), (3, 3, False), (3, 3, True), (4, None, False), (4, 2, False), (4, 2, True), (4, 4, False), (4, 4, True)], @@ -1421,7 +1416,6 @@ def ones_tensor(*shape): _ = mha_f(qkv_f, qkv_f, qkv_f, attn_mask=mask, need_weights=False, is_causal=True) torch.cuda.synchronize() - @skipIfRocm # Missing EFFICIENT_ATTENTION @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware" ) @@ -1714,7 +1708,7 @@ def test_unaligned_tensors(self, device): make_tensor = partial(torch.rand, size, device=device, dtype=dtype) q, k, v = make_tensor(), make_tensor(), make_tensor() with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): - ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext() + ctxmgr = self.assertRaises(RuntimeError) with ctxmgr: torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) @@ -2612,7 +2606,6 @@ def convert_flash_attn_S_to_softmax( S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) return S_converted[:, :, :seqlen_q, :seqlen_k] - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2636,7 +2629,6 @@ def test_cudnn_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_gqa(self, device): batch = 4 @@ -2660,7 +2652,6 @@ def test_cudnn_attention_gqa(self, device): self.assertEqual(output_math, output_cudnn) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_cudnn_attention_d256_heuristic(self, device): dtype = torch.bfloat16 @@ -2691,7 +2682,6 @@ def test(): with self.assertRaisesRegex(RuntimeError, "No available kernel."): test() - @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") def test_fused_attention_different_dk_dv(self, device): dtype = torch.bfloat16 @@ -2715,7 +2705,7 @@ def test_fused_attention_different_dk_dv(self, device): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2) - @skipIfRocm # No cuDNN Attention + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") @unittest.skipIf(True, "broken as of cuDNN 9.10") def test_cudnn_attention_fail_d128(self, device): # 
Test that cuDNN attention dispatching correctly bails out on d > 128 @@ -2737,7 +2727,6 @@ def test_cudnn_attention_fail_d128(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel."): torch.nn.functional.scaled_dot_product_attention(q, k, v) - @skipIfRocm(msg="No cuDNN on ROCm") @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_trivial_output_transpose(self, device): # see also: https://github.com/pytorch/pytorch/issues/134001 @@ -2753,7 +2742,6 @@ def test_cudnn_attention_trivial_output_transpose(self, device): o.backward(o) torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3) - @skipIfRocm # No cuDNN Attention @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_nonmodulo64seqlen(self, device): # see also: https://github.com/pytorch/pytorch/issues/137347 @@ -2793,7 +2781,6 @@ def test_cudnn_attention_nonmodulo64seqlen(self, device): torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) - @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_preserves_query_layout(self, device): @@ -2823,7 +2810,6 @@ def test_attention(backend: SDPBackend, permute_order: list[list[int]]): for permute_order in permute_orders: test_attention(SDPBackend.CUDNN_ATTENTION, list(permute_order) + [3]) - @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_compiles(self): q = torch.randn(2, 8, 1024, 128, dtype=torch.half, device='cuda', requires_grad=True) @@ -3242,7 +3228,6 @@ def test_sdp_choice_with_determinism(self, device, warn_only): with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value - @skipIfRocm @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") diff --git a/third_party/fbgemm b/third_party/fbgemm index 4b39c551efe15..3cefe0564a8c3 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 4b39c551efe15e6bbade20565b0ceb2d8ce3352d +Subproject commit 3cefe0564a8c3de514a152d40a2b4770f2ee5be0 diff --git a/tools/linter/dictionary.txt b/tools/linter/dictionary.txt index 706881a8f10f6..c4a250db04836 100644 --- a/tools/linter/dictionary.txt +++ b/tools/linter/dictionary.txt @@ -12,6 +12,7 @@ BU contiguities contiguity coo +DEPENDEES deser din dout diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 80437aa1d833e..5fe3f7e178b73 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1695,11 +1695,6 @@ class _DispatchModule: _after_ADInplaceOrView_keyset: DispatchKeySet _after_autograd_keyset: DispatchKeySet -class _SafeKernelFunction: - def call_boxed(self, keyset: DispatchKeySet, *args, **kwargs) -> Any: ... - @property - def op_handle(self) -> _DispatchOperatorHandle: ... - def _dispatch_library( kind: str, name: str, @@ -1737,10 +1732,6 @@ def _dispatch_has_computed_kernel_for_dispatch_key( name: str, dispatch: _dispatchkey, ) -> _bool: ... 
-def _dispatch_get_computed_kernel_for_dispatch_key( - name: str, - dispatch: _dispatchkey, -) -> _SafeKernelFunction: ... def _dispatch_find_dangling_impls() -> list[str]: ... def _dispatch_get_all_op_names() -> list[str]: ... def _dispatch_tls_set_dispatch_key_excluded( diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index 417fac7b4f634..2189e44f9e246 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -232,6 +232,18 @@ def should_use_persistent_reduction( features.reduction_numel, threshold ) # type: ignore[arg-types] + @staticmethod + def want_no_x_dim(features: SIMDKernelFeatures) -> bool: + """ + Heuristic to decide if we should drop the X dimension from a persistent reduction kernel. + So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1. + Strangely this is faster than a [1, RBLOCK] block in some cases. + + ROCm branch change: Remove want_no_x_dim for persistent reduction. + Inductor benchmarks show no perf advantage and simplifies autotune flow. + """ + return False + @staticmethod def reduction_split_factor( device: torch.device, diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 175ea55ec3af2..17a336cc3cf2e 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1101,11 +1101,17 @@ def relu(x): @staticmethod def minimum(a, b): - return f"triton_helpers.minimum({a}, {b})" + if torch.version.hip: + return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)" + else: + return f"triton_helpers.minimum({a}, {b})" @staticmethod def maximum(a, b): - return f"triton_helpers.maximum({a}, {b})" + if torch.version.hip: + return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)" + else: + return f"triton_helpers.maximum({a}, {b})" @staticmethod def where(a, b, c): @@ -1291,7 +1297,10 @@ def load_seed(name, offset): @staticmethod @maybe_upcast_float32() def rsqrt(x): - return f"libdevice.rsqrt({x})" + if torch.version.hip: + return f"tl.rsqrt({x})" + else: + return f"libdevice.rsqrt({x})" @staticmethod @maybe_upcast_float32() @@ -1306,7 +1315,7 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - return f"libdevice.tanh({x})" + return f"libdevice.fast_tanhf({x})" @staticmethod @maybe_upcast_float32() @@ -2030,12 +2039,11 @@ def should_use_persistent_reduction(self) -> bool: ) def want_no_x_dim(self): - return ( - self.persistent_reduction - and len(self.numels) == self.num_reduction_dims + 1 - and self.fixed_config - and self.fixed_config["XBLOCK"] == 1 - ) + """ + ROCm branch change: Remove want_no_x_dim for persistent reduction. + Inductor benchmarks show no perf advantage and simplifies autotune flow. 
+ """ + return False @property def assert_function(self) -> str: @@ -3789,8 +3797,9 @@ def codegen_body(self): loop_end = ( "rsplit_end" if self.cooperative_reduction else f"{prefix}numel" ) + num_stages = ", num_stages = 2" if torch.version.hip else "" self.body.writeline( - f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):" + f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):" ) with self.body.indent(offset=level + 1): self.iteration_ranges_codegen_header(tree, self.body) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index dc2392119cc51..94a905e4211ce 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -614,7 +614,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index f6921a057ba0f..cbe2bb44f6c86 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1391,7 +1391,7 @@ class triton: # So far we see a fixed 8 spilled registers for kernels using sin/cos. # Raise the threshold to 16 to be safe. # We should revisit this once we understand more of the source of register spills. - spill_threshold: int = 16 + spill_threshold: int = 32 if torch.version.hip else 16 # Generate code containing the newer tl.make_block_ptr() API for loads/store use_block_ptr = False @@ -1442,6 +1442,15 @@ class triton: os.environ.get("TORCHINDUCTOR_DECOMPOSE_K_THRESHOLD", "32") ) + # Map for storing the amount of kernel runs with dumped imput tensors + # Based on hash of Triton source code to avoid bloating the folder + kernel_dump_occurency_map: dict[str, int] = {} + + # Value for the maximum amount of runs with dumped kernel input tensors + # When the maximum is reached the first values get overwritten + # This ensures the last N runs are saved, where N is this value + max_kernel_dump_occurencies = 3 + class aot_inductor: """ diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index ad7a0d56fc4b1..26b3bcf5cc5cf 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -3,6 +3,7 @@ import itertools import logging from typing import Callable, Optional, TYPE_CHECKING +from functools import lru_cache from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -60,10 +61,16 @@ def get_config_max(self, prefix: str) -> int: size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None return min(max_block, size_hint) if size_hint is not None else max_block + @lru_cache(maxsize=1) def get_warpsmax(self): - # Currently, CUDA has a maximum of 1024 threads, so 32 is the max - # number of warps. 
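Aside on the codegen_body() hunk above (torch/_inductor/codegen/triton.py): the generated reduction loop header now uses tl.range on all builds, and on ROCm it additionally passes num_stages = 2 so Triton can software-pipeline the loop. The sketch below only re-creates the emitted loop-header string outside of Inductor so the two shapes are visible side by side; render_reduction_loop is an invented helper, and the loop_start values are an assumption since the hunk does not show how they are derived.

import torch


def render_reduction_loop(prefix: str = "r0_", cooperative_reduction: bool = False) -> str:
    # Mirrors the string built by codegen_body() in the hunk above.
    loop_start = "rsplit_start" if cooperative_reduction else "0"  # assumed
    loop_end = "rsplit_end" if cooperative_reduction else f"{prefix}numel"
    num_stages = ", num_stages = 2" if torch.version.hip else ""
    return (
        f"for {prefix}offset in tl.range({loop_start}, {loop_end}, "
        f"{prefix.upper()}BLOCK{num_stages}):"
    )


# ROCm build:  for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):
# CUDA build:  for r0_offset in tl.range(0, r0_numel, R0_BLOCK):
print(render_reduction_loop())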
- return 1024 // 32 + # CUDA/ROCm has a maximum of 1024 threads per block + from torch.cuda import current_device, get_device_properties, is_available + + warp_size = ( + get_device_properties(current_device()).warp_size if is_available() else 32 + ) + + return 1024 // warp_size def cache_benchmark_result(self, config, timing): self.cached_benchmark_results[triton_config_to_hashable(config)] = timing diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 15b86b1b3d1ae..a1a0a792c9b84 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -13,7 +13,7 @@ # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values # NOTE: if these fail asserts submit a PR to increase them TRITON_MAX_BLOCK = { - "X": 4096, + "X": 8192, "Y": 1024, "Z": 1024, "R0_": 4096 * 16, # * 16 is multi-kernel only diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 547fad5222465..ffcfb98a6bf32 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -34,6 +34,7 @@ from torch._environment import is_fbcode from torch._prims_common import compute_required_storage_length from torch.utils._ordered_set import OrderedSet +from torch._inductor.config import triton as inuctor_triton_config from ..triton_bundler import TritonBundler from ..utils import prefix_is_reduction, triton_version_uses_attrs_dict @@ -223,6 +224,39 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid): f.write(f"{kernel_name} | {args_str} | {grid!r}\n") +def _dump_launch_tensors(args, kernel_path, kernel_hash, kernel_name): + tensor_list = [arg for arg in args if isinstance(arg, torch.Tensor)] + + run_index = 0 + + # Some kernels don't have path and hash stored + # Using only the name to differentiate between those + if not kernel_path: + kernel_hash = kernel_name + + # Saving only the last N runs of the kernels to avoid bloating the folder + if kernel_hash in inuctor_triton_config.kernel_dump_occurency_map: + run_index = inuctor_triton_config.kernel_dump_occurency_map[kernel_hash] + 1 + + if run_index >= inuctor_triton_config.max_kernel_dump_occurencies: + run_index = 0 + + inuctor_triton_config.kernel_dump_occurency_map[kernel_hash] = run_index + + # Default path for kernels with no hash + if not kernel_path: + directory_path = "/tmp/torchinductor_root/unhashed_kernel_inputs" + else: + directory_path = os.path.dirname(kernel_path) + directory_path = f"{directory_path}/{kernel_name}_run_{run_index}" + os.makedirs(directory_path, exist_ok=True) + + tensor_index = 0 + for tensor in tensor_list: + torch.save(tensor, f"{directory_path}/tensor_{tensor_index}.pt") + tensor_index +=1 + + def check_autotune_cache( configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any] ) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]: @@ -367,6 +401,10 @@ def __init__( self.dump_launch_params = ( os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" ) + self.dump_launch_tensors = ( + os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_TENSORS", "0") == "1" + ) + self.kernels_to_dump = os.environ.get("TORCHINDUCTOR_KERNELS_TO_DUMP", "").split(",") self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" @@ -838,7 +876,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs): # for some (complicated) custom Triton kernels, a register-spilling # config may yield the best latency. 
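Aside on the _dump_launch_tensors() addition above: on a build that carries this patch, setting TORCHINDUCTOR_DUMP_LAUNCH_TENSORS=1 (optionally narrowed by TORCHINDUCTOR_KERNELS_TO_DUMP) makes the launcher save every tensor argument of a kernel launch as tensor_<i>.pt files, keeping at most max_kernel_dump_occurencies run directories per kernel. The sketch below shows one plausible way to drive it and read the dumps back; the compiled function, the commented-out kernel-name filter, and the glob over /tmp/torchinductor_* are illustrative guesses, since the actual dump directory sits next to the cached kernel source (or under /tmp/torchinductor_root/unhashed_kernel_inputs when no path is recorded).

import glob
import os

# Both switches are read when the caching autotuner is constructed,
# so set them before anything is compiled.
os.environ["TORCHINDUCTOR_DUMP_LAUNCH_TENSORS"] = "1"
# os.environ["TORCHINDUCTOR_KERNELS_TO_DUMP"] = "triton_poi_fused"  # optional substring filter

import torch


def f(x, y):
    return torch.nn.functional.relu(x + y)


compiled = torch.compile(f)
x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
compiled(x, y)

# Each run is written to <kernel dir>/<kernel name>_run_<i>/tensor_<j>.pt.
for path in sorted(glob.glob("/tmp/torchinductor_*/**/tensor_*.pt", recursive=True)):
    t = torch.load(path)
    print(path, tuple(t.shape), t.dtype)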
if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get( - "spill_threshold", 16 + "spill_threshold", 32 if torch.version.hip else 16 ): log.debug( "Skip config %s because of register spilling: %d", @@ -1291,6 +1329,11 @@ def alloc_fn(size: int, align: int, stream: Optional[int]): new_args, grid = self._interpret_args_grid(args, launcher.config) _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) + if self.dump_launch_tensors: + # Check the kernel name if the list was provided + if not self.kernels_to_dump or any(kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump): + _dump_launch_tensors(args, self.filename, self.kernel_hash, self.fn.__name__) + # it is faster than entering and exiting a context manager, even if the context # manager is a nullcontext. if autograd_profiler._is_profiler_enabled: @@ -2163,6 +2206,9 @@ def triton_config( num_stages=1, num_elements_per_warp=256, min_elem_per_thread=0, + num_warps=None, + matrix_instr=None, + waves_per_eu=None, ) -> Config: """ Construct a pointwise triton config with some adjustment heuristics @@ -2219,9 +2265,11 @@ def triton_config( ): z *= 2 - num_warps = _num_warps( - conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 - ) + # Calculate num_warps if they are not hard passed to config + if num_warps is None: + num_warps = _num_warps( + conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + ) # we are going to arrive at 2 warps only if bs was too small due to # numel being too small. However to workaround some ptx bugs we still # want at least 4 warps if there's enough elements per thread @@ -2251,7 +2299,15 @@ def triton_config( cfg["ZBLOCK"] = z check_max_block(cfg) check_config(cfg, xnumel=xnumel, ynumel=ynumel, znumel=znumel) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + + if torch.version.hip: + if matrix_instr is not None: + config.kwargs["matrix_instr_nonkdim"] = matrix_instr + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: @@ -2299,6 +2355,7 @@ def triton_config_reduction( num_stages=1, num_warps=None, register_intensive=False, + waves_per_eu=None, dynamic_scale_rblock=True, ) -> Config: """ @@ -2343,13 +2400,19 @@ def total_numel() -> int: cfg = _get_config({"x": x, **rnumels}) check_max_block(cfg) check_config(cfg, xnumel=size_hints["x"]) - return InductorConfig( + config = InductorConfig( cfg, num_warps=num_warps, num_stages=num_stages, dynamic_scale_rblock=dynamic_scale_rblock, ) + if torch.version.hip: + if waves_per_eu is not None: + config.kwargs["waves_per_eu"] = waves_per_eu + + return config + def _get_config(numels: dict[str, int]) -> dict[str, int]: """ @@ -2360,7 +2423,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]: def triton_config_tiled_reduction( - size_hints, x, y, r, num_stages=1, register_intensive=False + size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None ): """ Construct a tile reduction triton config with some adjustment @@ -2397,7 +2460,11 @@ def total_numel() -> int: ) check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) check_max_block(cfg) - return Config(cfg, num_warps=num_warps, num_stages=num_stages) + config = Config(cfg, num_warps=num_warps, num_stages=num_stages) + if torch.version.hip: + if waves_per_eu is not None: + 
config.kwargs["waves_per_eu"] = waves_per_eu + return config def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]): @@ -2486,11 +2553,38 @@ def pointwise( triton_config_with_settings( size_hints, bs // 2, num_elements_per_warp=64 ), + triton_config_with_settings( + size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 + ), + triton_config_with_settings( + size_hints, + 4096, # wrt: better than the max_block for some kernel + ), *hinted_configs, ] + # Additional reduction configs appended for ROCm builds + if torch.version.hip: + configs.append( + triton_config_with_settings( + size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1 + ) + ) # 20% improvement + configs += [ + triton_config_with_settings(size_hints, 2048, num_warps=8, num_stages=2, waves_per_eu=1), # 20% improvement # .. in where? + triton_config_with_settings(size_hints, 4096), # wrt1: better than the max_block for some kernel + triton_config_with_settings(size_hints, 128, num_warps=2, num_stages=2, waves_per_eu=1), + # -> wrt1/t18: 2X improvement: triton_poi_fused_index_put_new_zeros_37, + # triton_poi_fused_index_put_new_zeros_45 + # triton_poi_fused_index_put_new_zeros_49 + # triton_poi_fused_index_put_new_zeros_54 + triton_config_with_settings(size_hints, 128, num_warps=1, num_stages=1), # wri0: 56 us: triton_poi_fused_cat_mul_sigmoid_view_51 + ] if len(size_hints) == 2: + # Only avoiding tuning on TileHint.SQUARE if not on ROCm builds + # ROCm has observed improvement by diverging here if ( - disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE + disable_pointwise_autotuning(inductor_meta) + or (torch.version.hip is None and tile_hint == TileHint.SQUARE) ) and not ( inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") @@ -2499,13 +2593,36 @@ def pointwise( else: configs = [ triton_config_with_settings(size_hints, 32, 32), + triton_config_with_settings( + size_hints, 64, 32 + ), # better for some kernels triton_config_with_settings(size_hints, 64, 64), # ~8% better for fp16 - triton_config_with_settings(size_hints, 256, 16), + triton_config_with_settings(size_hints, 256, 16), triton_config_with_settings(size_hints, 16, 256), + triton_config_with_settings( + size_hints, 128, 16 + ), # +10% for some kernels + triton_config_with_settings(size_hints, 128, 32), # additional 10% more + triton_config_with_settings( + size_hints, 32, 512 + ), # +30% for some kernels triton_config_with_settings(size_hints, bs, 1), triton_config_with_settings(size_hints, 1, bs), *hinted_configs, ] + if torch.version.hip: + configs += [ # add here + ] + # bypass triton_config_with_settings -> triton_config logic + if "x" in size_hints and "y" in size_hints: + configs += [ + Config({"XBLOCK": 512, "YBLOCK": 8}, num_warps=8), # wrt1/t21 # triton_poi_fused__unsafe_view_add_addmm_cat_clone_permute_split_with_sizes_view_19 + Config({"XBLOCK": 32, "YBLOCK": 128}, num_warps=4), # wrt2: 570us : triton_poi_fused_add_transpose_view_52 + Config({"XBLOCK":64, "YBLOCK": 32}, num_warps=8), # wrt3: 150us: triton_poi_fused__to_copy_add_native_layer_norm_native_layer_norm_backward_permute_view_103 + Config({"XBLOCK":64, "YBLOCK": 256}, num_warps=4), # wri0: 70us: triton_poi_fused_clone_tanh_transpose_19 + Config({"XBLOCK":512, "YBLOCK": 64}, num_warps=8), # wri0: 58us: triton_poi_fused_clone_53 + ] + if len(size_hints) == 3: if disable_pointwise_autotuning(inductor_meta): configs = [triton_config_with_settings(size_hints, 16, 16, 16)] @@ -2544,6 +2661,11 @@ def _reduction_configs( # 
Convert reductions to 1D, to simplify heuristics. rnumel = get_total_reduction_numel(size_hints) + # Is max autotune enabled + max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get( + "max_autotune_pointwise" + ) + register_intensive = False MAX_R0_BLOCK = 2048 loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get( @@ -2572,6 +2694,7 @@ def make_config( num_stages=1, register_intensive=False, dynamic_scale_rblock=True, + waves_per_eu=None, ): # For 3D case with tiling scores, create an adapted version if "y" in size_hints: @@ -2584,6 +2707,7 @@ def make_config( num_warps=num_warps, num_stages=num_stages, register_intensive=register_intensive, + waves_per_eu=waves_per_eu, ) else: # For other cases, use the original function @@ -2594,6 +2718,7 @@ def make_config( num_warps=num_warps, num_stages=num_stages, register_intensive=register_intensive, + waves_per_eu=waves_per_eu, dynamic_scale_rblock=dynamic_scale_rblock, ) @@ -2674,33 +2799,45 @@ def outer_config_opt(): ) configs.append(c) + result_configs = [] + # For 3d tiling, default to more autotuning initially - if "y" in size_hints: - pass - elif inductor_meta.get("max_autotune") or inductor_meta.get( - "max_autotune_pointwise" - ): - pass # skip all these cases - elif reduction_hint == ReductionHint.INNER: - return configs + [contiguous_config] - elif reduction_hint == ReductionHint.OUTER: - return configs + [outer_config] - elif reduction_hint == ReductionHint.OUTER_TINY: - return configs + [tiny_config] - if disable_pointwise_autotuning(inductor_meta): - return configs + [make_config(32, 128)] - - return configs + [ - contiguous_config, - outer_config, - tiny_config, - make_config(64, 64), - make_config(8, 512), - # halve the XBLOCK/Rn_BLOCK compared to outer_config - # TODO: this may only be beneficial when each iteration of the reduction - # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 - make_config(64, 4, num_warps=8), - ] + if not (max_autotune_enabled or "y" in size_hints): + if reduction_hint == ReductionHint.INNER: + result_configs = configs + [contiguous_config] + elif reduction_hint == ReductionHint.OUTER: + result_configs = configs + [outer_config] + elif reduction_hint == ReductionHint.OUTER_TINY: + result_configs = configs + [tiny_config] + else: + result_configs = configs + [make_config(32, 128)] + else: + result_configs = configs + [ + contiguous_config, + outer_config, + tiny_config, + make_config(64, 64), + make_config(8, 512), + # halve the XBLOCK/Rn_BLOCK compared to outer_config + # TODO: this may only be beneficial when each iteration of the reduction + # is quite heavy. E.g. 
https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72 + make_config(64, 4, num_warps=8), + ] + + if torch.version.hip: + result_configs.extend( + [ + make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), + make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), + make_config(128, 4, num_warps=2, num_stages=1, waves_per_eu=1), # wrt2: 3X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8 + make_config(1, 512, num_warps=8, num_stages=1, waves_per_eu=1), # wrt2: 2X # triton_red_fused_index_add_index_select_mul_native_layer_norm_native_layer_norm_backward_new_zeros_sigmoid_8-v2 & v3 & v4 + make_config(1, 4096, num_warps=8, num_stages=1, waves_per_eu=1), # wrt3: 380 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_slice_tanh_tanh_backward_153 + make_config(64, 128, num_warps=4, num_stages=1, waves_per_eu=1), # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_add_addmm_cat_clone_native_layer_norm_permute_tanh_view_16 + make_config(2, 2048, num_warps=8, num_stages=1, waves_per_eu=1) # wrt3: 170 us # triton_red_fused__to_copy__unsafe_view_clone_native_layer_norm_native_layer_norm_backward_permute_tanh_tanh_backward_29 + ] + ) + + return result_configs def match_target_block_product( @@ -2758,6 +2895,7 @@ def adapt_config_for_tiling( num_stages=1, register_intensive=False, persistent_reduction=False, + waves_per_eu=None, ) -> Config: """ Create an adapted configuration based on tiling scores, @@ -2776,6 +2914,7 @@ def adapt_config_for_tiling( block_sizes["r0_"], num_stages=num_stages, register_intensive=register_intensive, + waves_per_eu=waves_per_eu, ) @@ -2868,13 +3007,25 @@ def _persistent_reduction_configs( ): xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) + loads_and_stores = inductor_meta.get("num_load", 0) + inductor_meta.get( + "num_store", 0 + ) MAX_PERSISTENT_BLOCK_NUMEL = 4096 + max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ) + + if torch.version.hip: + xblock_vals = [1, 4, 8, 16, 32, 64, 128, 256] + else: + xblock_vals = [1, 8, 32, 128] + if "y" not in size_hints: configs = [ triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True) - for xblock in (1, 8, 32, 128) + for xblock in xblock_vals if xblock == 1 or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel) ] @@ -2882,7 +3033,7 @@ def _persistent_reduction_configs( configs = [] assert "tiling_scores" in inductor_meta x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")} - for target_block_size in (1, 8, 32, 64, 128): + for target_block_size in xblock_vals: if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL: continue @@ -2895,22 +3046,49 @@ def _persistent_reduction_configs( ) ) + tiny_configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + # defer to more autotuning, initially if "y" in size_hints: pass # TODO(jansel): we should be able to improve these heuristics - elif reduction_hint == ReductionHint.INNER and rnumel >= 256: - configs = configs[:1] - elif reduction_hint == ReductionHint.OUTER: - configs = configs[-1:] - elif reduction_hint == ReductionHint.OUTER_TINY: - configs = [ - triton_config_reduction( - size_hints, - 2 * (256 // rnumel) if rnumel <= 256 else 1, - rnumel, - ) - ] + elif not max_autotune_enabled: # Do 
not filter configs when tuning + if reduction_hint == ReductionHint.INNER: + if rnumel > 1024: + configs = configs[:1] + else: + x_block = 8 + if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256): + # If loads/stores greater than 5, a lot of register pressure + # rnumel < 256 means no vectorized loads if we split up r dim + # so xblock still needs to be larger + x_block = 1 + + configs = [ + triton_config_reduction( + size_hints, + x_block, + rnumel, + register_intensive=True, + ) + ] + + elif reduction_hint == ReductionHint.OUTER: + configs = configs[-1:] + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = tiny_configs + else: + # If autotune is enabled append tiny configs + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) + for c in configs: # we don't need Rn_BLOCK for persistent reduction for prefix in size_hints: @@ -3102,13 +3280,24 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + + # Naive autotuning path for num_warps + if disable_pointwise_autotuning(inductor_meta) and not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h index 9728d27d4d79b..0ac2c79d1e98a 100644 --- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h +++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h @@ -260,7 +260,7 @@ typedef __half half; )"; #endif -#if defined(USE_ROCM) +#if defined(USE_ROCM) && ROCM_VERSION < 70000 constexpr auto bfloat16_support_literal = R"( #ifndef __align__ @@ -317,6 +317,75 @@ __device__ __nv_bfloat16 __float2bfloat16(const float a) { return val; } +__device__ float __bfloat162float(const __nv_bfloat16 a) { + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(a.__x) << 16}; + return u.fp32; +} +#endif /* defined(__cplusplus) */ +)"; +#elif defined(USE_ROCM) && ROCM_VERSION >= 70000 +constexpr auto bfloat16_support_literal = + R"( +#ifndef __align__ +#define __align__(x) __attribute__((aligned(x))) +#endif + +typedef unsigned int uint32_t; + +typedef struct __align__(2) { + unsigned short x; +} +__nv_bfloat16_raw; + +#if defined(__cplusplus) +struct __align__(2) __nv_bfloat16 { + __host__ __device__ __nv_bfloat16() {} + + __host__ __device__ __nv_bfloat16& operator=(const __nv_bfloat16_raw& hr) { + __x = hr.x; + return *this; + } + + unsigned short __x; +}; + +__device__ unsigned short __internal_float2bfloat16( + const float f, + unsigned int& sign, + unsigned int& remainder) { + unsigned int x; + + x = __float_as_uint(f); + + if ((x & 0x7fffffffU) > 0x7f800000U) { + sign = 0U; + remainder = 0U; + return static_cast(0x7fffU); + } + sign = x >> 31; + remainder = x << 16; + return static_cast(x >> 16); +} + +/* Definitions of intrinsics */ +__device__ __nv_bfloat16 __float2bfloat16(const float a) { + __nv_bfloat16 val; + __nv_bfloat16_raw r; + unsigned int sign; + unsigned int remainder; + r.x = __internal_float2bfloat16(a, sign, 
remainder); + if ((remainder > 0x80000000U) || + ((remainder == 0x80000000U) && ((r.x & 0x1U) != 0U))) { + r.x++; + } + val = r; + return val; +} + __device__ float __bfloat162float(const __nv_bfloat16 a) { union { diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 9d6eb35c71789..07fa4ea5e1dd7 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -620,43 +620,6 @@ void initDispatchBindings(PyObject* module) { c10::parseDispatchKey(dispatch)); }); - // Bind SafeKernelFunction class - py::class_(m, "_SafeKernelFunction") - .def( - "call_boxed", - [](const c10::SafeKernelFunction& self, - c10::DispatchKeySet keyset, - py::args args, - const py::kwargs& kwargs) { - const auto& op = self.opHandle(); - auto stack = torch::jit::createStackForSchema( - op.schema(), - std::move(args), - kwargs, - /*self=*/std::nullopt); - self.callBoxed(op, keyset, &stack); - return torch::jit::createPyObjectForStack(std::move(stack)); - }) - .def( - "__repr__", - [](const c10::SafeKernelFunction& self) { - return "SafeKernelFunction(debug='" + self.debug() + "')"; - }) - .def_property_readonly( - "op_handle", [](const c10::SafeKernelFunction& self) -> py::object { - return py::cast(self.opHandle()); - }); - - m.def( - "_dispatch_get_computed_kernel_for_dispatch_key", - [](const char* name, - c10::DispatchKey dispatch) -> c10::SafeKernelFunction { - auto op = - c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); - TORCH_CHECK(op, "operator ", name, " does not exist"); - return op->getComputedKernelForDispatchKey(dispatch); - }); - m.def("_dispatch_find_dangling_impls", []() -> std::vector { auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py index c3982c33315e2..d1ac7fad7480b 100644 --- a/torch/cuda/tunable.py +++ b/torch/cuda/tunable.py @@ -591,7 +591,6 @@ def _process_single_offline_gemm(untuned_gemm_line: str, gpu_id: int) -> None: transA = layout[1] == "T" dtype = dtype_dict.get(data_type) if data_type == "tf32": - # User must still set HIPBLASLT_ALLOW_TF32=1 torch.backends.cuda.matmul.allow_tf32 = True else: torch.backends.cuda.matmul.allow_tf32 = False diff --git a/torch/library.py b/torch/library.py index d36c181581483..372037f09dbe5 100644 --- a/torch/library.py +++ b/torch/library.py @@ -45,7 +45,6 @@ "register_torch_dispatch", "register_vmap", "get_ctx", - "get_kernel", "custom_op", "triton_op", "wrap_triton", @@ -1476,80 +1475,6 @@ def get_ctx() -> "torch._library.fake_impl.FakeImplCtx": return torch._library.fake_impl.global_ctx_getter() -def get_kernel( - op: _op_identifier, dispatch_key: Union[str, torch.DispatchKey] -) -> torch._C._SafeKernelFunction: - """Returns the computed kernel for a given operator and dispatch key. - - This function retrieves the kernel that would be executed for a given - operator and dispatch key combination. The returned SafeKernelFunction - can be used to call the kernel in a boxed fashion. The intended use - case for this function is to retrieve the original kernel for a given - dispatch key and then register another kernel to the same dispatch key - that calls into the original kernel for certain cases. - - Args: - op: Operator name (along with the overload) or OpOverload object - Can be a string (e.g., "aten::add.Tensor"), an OpOverload, or a CustomOpDef. - dispatch_key (str | torch.DispatchKey): The dispatch key to get the kernel for. 
- Can be a string (e.g., "CPU", "CUDA") or a DispatchKey enum value. - - Returns: - torch._C._SafeKernelFunction: A safe kernel function that can be used to - call the kernel. - - Raises: - RuntimeError: If the operator does not exist. - - Example: - >>> # Get the CPU kernel for torch.add - >>> kernel = torch.library.get_kernel("aten::add.Tensor", "CPU") - >>> - >>> # You can also use DispatchKey enum - >>> kernel = torch.library.get_kernel("aten::add.Tensor", torch.DispatchKey.CPU) - >>> - >>> # Or use an OpOverload directly - >>> kernel = torch.library.get_kernel(torch.ops.aten.add.Tensor, "CPU") - >>> - >>> # Example: Using get_kernel in a custom op with conditional dispatch - >>> # Get the original kernel for torch.sin - >>> original_sin_kernel = torch.library.get_kernel("aten::sin", "CPU") - >>> - >>> # If input has negative values, use original sin, otherwise return zeros - >>> def conditional_sin_impl(dispatch_keys, x): - >>> if (x < 0).any(): - >>> return original_sin_kernel.call_boxed(dispatch_keys, x) - >>> else: - >>> return torch.zeros_like(x) - >>> - >>> lib = torch.library.Library("aten", "IMPL") - >>> # with_keyset=True so the first argument to the impl is the current DispatchKeySet - >>> which needs to be the first argument to ``kernel.call_boxed`` - >>> lib.impl("sin", conditional_sin_impl, "CPU", with_keyset=True) - >>> - >>> # Test the conditional behavior - >>> x_positive = torch.tensor([1.0, 2.0]) - >>> x_mixed = torch.tensor([-1.0, 2.0]) - >>> torch.sin(x_positive) - tensor([0., 0.]) - >>> torch.sin(x_mixed) - tensor([-0.8415, 0.9093]) - """ - if not isinstance(op, (str, torch._ops.OpOverload)): - raise ValueError(f"get_kernel({op}): got unexpected type for op: {type(op)}") - - if isinstance(op, torch._ops.OpOverload): - op = op._name - - if isinstance(dispatch_key, str): - try: - dispatch_key = torch._C.DispatchKey.__members__[dispatch_key] - except KeyError: - raise ValueError(f"Invalid dispatch key: {dispatch_key}") from None - - return torch._C._dispatch_get_computed_kernel_for_dispatch_key(op, dispatch_key) - - _OPCHECK_DEFAULT_UTILS = ( "test_schema", "test_autograd_registration", diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index be284429114f5..846d2b407684c 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -181,9 +181,6 @@ def tf32_off(): @contextlib.contextmanager def tf32_on(self, tf32_precision=1e-5): - if torch.version.hip: - hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None) - os.environ["HIPBLASLT_ALLOW_TF32"] = "1" old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32 old_precision = self.precision try: @@ -192,11 +189,6 @@ def tf32_on(self, tf32_precision=1e-5): with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True): yield finally: - if torch.version.hip: - if hip_allow_tf32 is not None: - os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32 - else: - del os.environ["HIPBLASLT_ALLOW_TF32"] torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul self.precision = old_precision @@ -246,7 +238,7 @@ def tf32_enabled(): # if device is specified, it will check if device is cuda # if dtype is specified, it will check if dtype is float32 or complex64 # tf32 and fp32 are different only when all the three checks pass -def tf32_on_and_off(tf32_precision=1e-5, only_if=True): +def tf32_on_and_off(tf32_precision=1e-5, *, only_if=True): def with_tf32_disabled(self, function_call): with tf32_off(): 
function_call() diff --git a/version.txt b/version.txt index 03e905f0db5fe..c8e38b614057b 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.9.0a0 +2.9.0
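Closing note on the common_cuda.py hunk: tf32_on_and_off now takes only_if as a keyword-only argument, and the ROCm-specific HIPBLASLT_ALLOW_TF32 plumbing is gone, which matches the call sites earlier in the patch that simply pass the tolerance. A rough usage sketch follows; the test class, the matmul check, and the only_if condition are all invented for illustration, and only_if exists only on trees that carry this patch.

import unittest

import torch
from torch.testing._internal.common_cuda import tf32_on_and_off
from torch.testing._internal.common_utils import TestCase, run_tests


class TestMatmulTF32(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "requires a GPU device")
    @tf32_on_and_off(0.001)
    def test_matmul_matches_cpu(self, device="cuda"):
        # The decorator runs this body twice, with TF32 matmul disabled and
        # enabled, loosening self.precision to 0.001 for the TF32 pass.
        a = torch.randn(64, 64, device=device)
        b = torch.randn(64, 64, device=device)
        self.assertEqual(a @ b, (a.cpu() @ b.cpu()).to(device))

    @unittest.skipIf(not torch.cuda.is_available(), "requires a GPU device")
    @tf32_on_and_off(0.001, only_if=torch.version.hip is None)  # keyword-only after this patch
    def test_matmul_matches_cpu_conditional_sweep(self, device="cuda"):
        a = torch.randn(64, 64, device=device)
        b = torch.randn(64, 64, device=device)
        self.assertEqual(a @ b, (a.cpu() @ b.cpu()).to(device))


if __name__ == "__main__":
    run_tests()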