From e3684b2fe2189cde698eb3fc7604d1fd0ae693fb Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Wed, 25 Mar 2026 13:29:01 -0700 Subject: [PATCH 01/10] [ROCm] Hipdnn core integration Integrate HipDNN with PyTorch when available (requires ROCm 7.12+). Includes: - cmake `USE_HIPDNN` detection - runtime `CUDAHooks::compiledWithHipDNN()` hook - simple `torch.backends.hipdnn` Python module Co-authored-by: Dmitry Nikolaev Signed-off-by: zjgarvey Assisted-by: Claude Opus 4.6 remove macro variable AT_HIPDNN_ENABLED The macro was causing internal build failures. The purpose of the macro is to throw a compile error when the generated header isn't included, but we only use the macro in a file where this is the case anyway, so we might as well directly query the associated preprocessor directive. Signed-off-by: zjgarvey --- aten/src/ATen/CMakeLists.txt | 11 ++++++ aten/src/ATen/Context.cpp | 8 ++++ aten/src/ATen/Context.h | 3 ++ aten/src/ATen/cuda/detail/CUDAHooks.cpp | 8 ++++ aten/src/ATen/cuda/detail/CUDAHooks.h | 1 + aten/src/ATen/detail/CUDAHooksInterface.h | 4 ++ aten/src/ATen/hipdnn/Exceptions.h | 48 +++++++++++++++++++++++ aten/src/ATen/hipdnn/Handle.cpp | 46 ++++++++++++++++++++++ aten/src/ATen/hipdnn/Handle.h | 10 +++++ aten/src/ATen/hipdnn/Types.cpp | 24 ++++++++++++ aten/src/ATen/hipdnn/Types.h | 12 ++++++ aten/src/ATen/hipdnn/Utils.h | 18 +++++++++ aten/src/ATen/hipdnn/hipdnn-wrapper.h | 3 ++ cmake/Dependencies.cmake | 4 ++ cmake/public/LoadHIP.cmake | 20 ++++++++++ torch/CMakeLists.txt | 3 ++ torch/_C/__init__.pyi.in | 3 ++ torch/backends/__init__.py | 1 + torch/backends/hipdnn/__init__.py | 45 +++++++++++++++++++++ torch/csrc/Module.cpp | 32 +++++++++++++++ torch/testing/_internal/common_cuda.py | 1 + 21 files changed, 305 insertions(+) create mode 100644 aten/src/ATen/hipdnn/Exceptions.h create mode 100644 aten/src/ATen/hipdnn/Handle.cpp create mode 100644 aten/src/ATen/hipdnn/Handle.h create mode 100644 aten/src/ATen/hipdnn/Types.cpp create mode 100644 
aten/src/ATen/hipdnn/Types.h create mode 100644 aten/src/ATen/hipdnn/Utils.h create mode 100644 aten/src/ATen/hipdnn/hipdnn-wrapper.h create mode 100644 torch/backends/hipdnn/__init__.py diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 06522a9298893..92cd4827d0901 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -102,6 +102,8 @@ file(GLOB hip_nvrtc_stub_h "hip/nvrtc_stub/*.h") file(GLOB hip_nvrtc_stub_cpp "hip/nvrtc_stub/*.cpp") file(GLOB miopen_h "miopen/*.h") file(GLOB miopen_cpp "miopen/*.cpp") +file(GLOB hipdnn_h "hipdnn/*.h") +file(GLOB hipdnn_cpp "hipdnn/*.cpp") file(GLOB mkl_cpp "mkl/*.cpp") file(GLOB mkldnn_cpp "mkldnn/*.cpp") @@ -179,6 +181,7 @@ file(GLOB native_hip_hip "native/hip/*.hip" "native/hip/bgemm_kernels/*.hip") file(GLOB native_hip_cpp "native/hip/*.cpp") file(GLOB native_hip_linalg_cpp "native/hip/linalg/*.cpp") file(GLOB native_miopen_cpp "native/miopen/*.cpp") +file(GLOB native_hipdnn_cpp "native/hipdnn/*.cpp") file(GLOB native_cudnn_hip_cpp "native/cudnn/hip/*.cpp") file(GLOB native_nested_hip_hip "native/nested/hip/*.hip") file(GLOB native_nested_hip_cpp "native/nested/hip/*.cpp") @@ -568,6 +571,7 @@ if(USE_CUDA) ${native_cuda_cpp} ${native_cudnn_cpp} ${native_miopen_cpp} + ${native_hipdnn_cpp} ${native_nested_cuda_cpp} ${native_quantized_cuda_cpp} ${native_quantized_cudnn_cpp} @@ -647,10 +651,14 @@ if(USE_ROCM) ${cuda_generated_sources} ${ATen_HIP_SRCS} ${native_miopen_cpp} + ${native_hipdnn_cpp} ${native_cudnn_hip_cpp} ${miopen_cpp} ${all_hip_cpp} ) + if(USE_HIPDNN) + list(APPEND all_hip_cpp ${hipdnn_cpp}) + endif() endif() if(USE_XPU) @@ -937,6 +945,9 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} 
${native_hip_h} ${native_mtia_h} ${cudnn_h} ${hip_h} ${mtia_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) + if(USE_HIPDNN) + list(APPEND INSTALL_HEADERS ${hipdnn_h}) + endif() # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index c342590b58c42..fa97b4d6bc1eb 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -319,6 +319,14 @@ void Context::setImmediateMiopen(bool b) { immediate_miopen = b; } +bool Context::userEnabledHipdnn() const { + return enabled_hipdnn; +} + +void Context::setUserEnabledHipdnn(bool e) { + enabled_hipdnn = e; +} + bool Context::allowTF32CuBLAS() const { bool legacy_allow_tf32 = float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST; bool allow_tf32_new = float32Precision(Float32Backend::CUDA, Float32Op::MATMUL) == Float32Precision::TF32; diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index de6a7dda66d73..8d4468911d380 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -244,6 +244,8 @@ class TORCH_API Context { void setBenchmarkLimitCuDNN(int /*b*/); bool immediateMiopen() const; void setImmediateMiopen(bool /*b*/); + bool userEnabledHipdnn() const; + void setUserEnabledHipdnn(bool e); bool deterministicCuDNN() const; void setDeterministicCuDNN(bool /*b*/); bool deterministicMkldnn() const; @@ -478,6 +480,7 @@ class TORCH_API Context { bool allow_fp16_bf16_reduction_mathSDP = false; bool benchmark_cudnn = false; bool immediate_miopen = false; + bool enabled_hipdnn = false; Float32MatmulPrecision float32_matmul_precision = c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true ? 
at::Float32MatmulPrecision::HIGH diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 5f81407b1ac03..89136a391ec4f 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -276,6 +276,14 @@ bool CUDAHooks::compiledWithMIOpen() const { return AT_ROCM_ENABLED(); } +bool CUDAHooks::compiledWithHipDNN() const { +#ifdef USE_HIPDNN + return true; +#else + return false; +#endif +} + bool CUDAHooks::supportsDilatedConvolutionWithCuDNN() const { #if AT_CUDNN_ENABLED() if (!hasCUDA()) { diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 6347cc6f40c85..ac30dd3349ef6 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -42,6 +42,7 @@ struct CUDAHooks : public at::CUDAHooksInterface { Allocator* getPinnedMemoryAllocator() const override; bool compiledWithCuDNN() const override; bool compiledWithMIOpen() const override; + bool compiledWithHipDNN() const override; bool supportsDilatedConvolutionWithCuDNN() const override; bool supportsDepthwiseConvolutionWithCuDNN() const override; bool supportsBFloat16ConvolutionWithCuDNNv8() const override; diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index 4ed00372fb778..f07ee9f670b9b 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -154,6 +154,10 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { return false; } + virtual bool compiledWithHipDNN() const { + return false; + } + virtual bool supportsDilatedConvolutionWithCuDNN() const { return false; } diff --git a/aten/src/ATen/hipdnn/Exceptions.h b/aten/src/ATen/hipdnn/Exceptions.h new file mode 100644 index 0000000000000..30eba78b9e0a3 --- /dev/null +++ b/aten/src/ATen/hipdnn/Exceptions.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include + +namespace c10 { + +class HipDNNError : 
public c10::Error { + using Error::Error; +}; + +} // namespace c10 + +#define HIPDNN_CHECK(EXPR, ...) \ + do { \ + hipdnnStatus_t status = EXPR; \ + if (status != HIPDNN_STATUS_SUCCESS) { \ + if (status == HIPDNN_STATUS_NOT_SUPPORTED) { \ + TORCH_CHECK_WITH( \ + HipDNNError, \ + false, \ + "hipDNN error: ", \ + hipdnnGetErrorString(status), \ + ". This error may appear if you passed in a non-contiguous" \ + " input.", \ + ##__VA_ARGS__); \ + } else { \ + TORCH_CHECK_WITH( \ + HipDNNError, \ + false, \ + "hipDNN error: ", \ + hipdnnGetErrorString(status), \ + ##__VA_ARGS__); \ + } \ + } \ + } while (0) + +#define HIPDNN_FE_CHECK(EXPR) \ + do { \ + auto error_object = EXPR; \ + if (!error_object.is_good()) { \ + TORCH_CHECK_WITH( \ + HipDNNError, \ + false, \ + "hipDNN Frontend error: ", \ + error_object.get_message()); \ + } \ + } while (0) diff --git a/aten/src/ATen/hipdnn/Handle.cpp b/aten/src/ATen/hipdnn/Handle.cpp new file mode 100644 index 0000000000000..3439d26a96e1e --- /dev/null +++ b/aten/src/ATen/hipdnn/Handle.cpp @@ -0,0 +1,46 @@ +#include +#include +#include + +#include +#include +#include + +namespace at::native { +namespace { + +void createHipdnnHandle(hipdnnHandle_t* handle) { + HIPDNN_CHECK(hipdnnCreate(handle)); +} + +void destroyHipdnnHandle(hipdnnHandle_t handle) { + // Intentionally not destroying handle to avoid shutdown ordering issues. + // See comments in the miopen equivalent (Handle.cpp). +} + +using HipDNNPoolType = at::cuda::DeviceThreadHandlePool< + hipdnnHandle_t, + createHipdnnHandle, + destroyHipdnnHandle>; + +} // namespace + +hipdnnHandle_t getHipdnnHandle() { + c10::DeviceIndex device = 0; + AT_CUDA_CHECK(c10::hip::GetDevice(&device)); + + // Thread local PoolWindows are lazily-initialized + // to avoid initialization issues that caused hangs on Windows. 
+ // See: https://github.com/pytorch/pytorch/pull/22405 + // This thread local unique_ptrs will be destroyed when the thread terminates, + // releasing its reserved handles back to the pool. + static auto pool = std::make_shared(); + thread_local std::unique_ptr myPoolWindow( + pool->newPoolWindow()); + + auto handle = myPoolWindow->reserve(device); + HIPDNN_CHECK(hipdnnSetStream(handle, c10::hip::getCurrentHIPStream())); + return handle; +} + +} // namespace at::native diff --git a/aten/src/ATen/hipdnn/Handle.h b/aten/src/ATen/hipdnn/Handle.h new file mode 100644 index 0000000000000..3e9af1e83f9c2 --- /dev/null +++ b/aten/src/ATen/hipdnn/Handle.h @@ -0,0 +1,10 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API hipdnnHandle_t getHipdnnHandle(); +} // namespace at::native diff --git a/aten/src/ATen/hipdnn/Types.cpp b/aten/src/ATen/hipdnn/Types.cpp new file mode 100644 index 0000000000000..750b3f0fea30e --- /dev/null +++ b/aten/src/ATen/hipdnn/Types.cpp @@ -0,0 +1,24 @@ +#include +#include + +#include + +namespace at::native { + +hipdnn_frontend::DataType getHipdnnDataType(const at::Tensor& tensor) { + switch (tensor.scalar_type()) { + case at::kFloat: + return hipdnn_frontend::DataType::FLOAT; + case at::kHalf: + return hipdnn_frontend::DataType::HALF; + case at::kBFloat16: + return hipdnn_frontend::DataType::BFLOAT16; + default: + TORCH_CHECK( + false, + "getHipdnnDataType() not supported for ", + toString(tensor.scalar_type())); + } +} + +} // namespace at::native diff --git a/aten/src/ATen/hipdnn/Types.h b/aten/src/ATen/hipdnn/Types.h new file mode 100644 index 0000000000000..90b7bd8567dec --- /dev/null +++ b/aten/src/ATen/hipdnn/Types.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include +#include + +namespace at::native { + +TORCH_CUDA_CPP_API hipdnn_frontend::DataType getHipdnnDataType( + const at::Tensor& tensor); + +} // namespace at::native diff --git a/aten/src/ATen/hipdnn/Utils.h 
b/aten/src/ATen/hipdnn/Utils.h new file mode 100644 index 0000000000000..a9f236fd630a3 --- /dev/null +++ b/aten/src/ATen/hipdnn/Utils.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native { + +inline std::shared_ptr +createTensorAttributes(const Tensor& t) { + auto tensor = std::make_shared(); + tensor->set_dim(t.sizes().vec()).set_data_type(getHipdnnDataType(t)); + tensor->set_stride(t.strides().vec()); + return tensor; +} + +} // namespace at::native diff --git a/aten/src/ATen/hipdnn/hipdnn-wrapper.h b/aten/src/ATen/hipdnn/hipdnn-wrapper.h new file mode 100644 index 0000000000000..877413a0c9ccd --- /dev/null +++ b/aten/src/ATen/hipdnn/hipdnn-wrapper.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 70c5d3de4cf23..d8fad4621bf90 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1086,6 +1086,10 @@ if(USE_ROCM) set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS hip::amdhip64 MIOpen hiprtc::hiprtc) # libroctx will be linked in with MIOpen + if(USE_HIPDNN) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS hipdnn_frontend) + list(APPEND HIP_CXX_FLAGS -DUSE_HIPDNN) + endif() # Math libraries list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index a87b16f5ba889..39154dcad2694 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -209,6 +209,26 @@ if(HIP_FOUND) find_package_and_print_version(hipsolver REQUIRED) find_package_and_print_version(rocsolver REQUIRED) find_package_and_print_version(rocshmem) + # hipdnn packages export include dirs via target INTERFACE_INCLUDE_DIRECTORIES + # rather than the ${PACKAGE_NAME}_INCLUDE_DIR variable that + # find_package_and_print_version checks, so we call find_package directly. 
+ find_package(hipdnn_frontend CONFIG) + if(hipdnn_frontend_FOUND) + message(STATUS "hipdnn_frontend VERSION: ${hipdnn_frontend_VERSION}") + get_target_property(_hipdnn_fe_includes hipdnn_frontend INTERFACE_INCLUDE_DIRECTORIES) + if(_hipdnn_fe_includes) + list(APPEND ROCM_INCLUDE_DIRS ${_hipdnn_fe_includes}) + set(USE_HIPDNN ON) + message(STATUS "Found hipDNN, enabling USE_HIPDNN") + else() + set(USE_HIPDNN OFF) + message(STATUS "hipDNN found but missing include directories, disabling USE_HIPDNN") + endif() + else() + set(USE_HIPDNN OFF) + message(STATUS "hipDNN not found, disabling USE_HIPDNN") + endif() + # workaround cmake 4 build issue if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") message(WARNING "Work around hiprtc cmake failure for cmake >= 4") diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 471fe0c9c1f0b..bb97ab91f234d 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -164,6 +164,9 @@ if(USE_ROCM) USE_ROCM __HIP_PLATFORM_AMD__ ) + if(USE_HIPDNN) + list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_HIPDNN) + endif() if(NOT WIN32) list(APPEND TORCH_PYTHON_LINK_LIBRARIES ${ROCM_ROCTX_LIB}) endif() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 87eef3e01fbc7..d49493c2d75c1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1203,6 +1203,8 @@ def set_num_interop_threads( ) -> None: ... # THPModule_setNumInteropThreads def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN +def _get_hipdnn_enabled() -> _bool: ... # THPModule_userEnabledHipdnn +def _set_hipdnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledHipdnn def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash def _get_fa3_sdp_enabled() -> _bool: ... 
# THPModule_userEnabledFA3SDP @@ -1464,6 +1466,7 @@ _has_xpu: _bool _has_mkldnn: _bool _has_mkldnn_acl: _bool _has_cudnn: _bool +_has_hipdnn: _bool _has_cusparselt: _bool has_spectral: _bool _GLIBCXX_USE_CXX11_ABI: _bool diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index f07a4797c64da..437f30e4afbf2 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -126,6 +126,7 @@ class GenericModule(PropModule): cuda as cuda, cudnn as cudnn, cusparselt as cusparselt, + hipdnn as hipdnn, kleidiai as kleidiai, mha as mha, miopen as miopen, diff --git a/torch/backends/hipdnn/__init__.py b/torch/backends/hipdnn/__init__.py new file mode 100644 index 0000000000000..c09442a1a9c05 --- /dev/null +++ b/torch/backends/hipdnn/__init__.py @@ -0,0 +1,45 @@ +# mypy: allow-untyped-defs +import sys +from contextlib import contextmanager + +import torch +from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule + + +def is_available(): + return torch._C._has_hipdnn + + +def set_flags( + _enabled=None, +): + orig_flags = (torch._C._get_hipdnn_enabled(),) + if _enabled is not None: + torch._C._set_hipdnn_enabled(_enabled) + return orig_flags + + +@contextmanager +def flags( + enabled=None, +): + with __allow_nonbracketed_mutation(): + orig_flags = set_flags( + enabled, + ) + try: + yield + finally: + # recover the previous values + with __allow_nonbracketed_mutation(): + set_flags(*orig_flags) + + +class HipdnnModule(PropModule): + enabled = ContextProp(torch._C._get_hipdnn_enabled, torch._C._set_hipdnn_enabled) + + +sys.modules[__name__] = HipdnnModule(sys.modules[__name__], __name__) + +# Add type annotation for the replaced module +enabled: bool diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index f3aee904e530c..4324bffc48941 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1333,6 +1333,29 @@ static PyObject* THPModule_immediateMiopen( Py_RETURN_FALSE; } +static PyObject* 
THPModule_setUserEnabledHipdnn( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + PyBool_Check(arg), + "set_hipdnn_enabled expects a bool, " + "but got ", + THPUtils_typename(arg)); + at::globalContext().setUserEnabledHipdnn(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THPModule_userEnabledHipdnn( + PyObject* _unused, + PyObject* noargs) { + if (at::globalContext().userEnabledHipdnn()) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + static PyObject* THPModule_setAllowTF32CuBLAS( PyObject* _unused, PyObject* arg) { @@ -1938,6 +1961,8 @@ static std::initializer_list TorchMethods = { {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr}, {"_get_miopen_immediate", THPModule_immediateMiopen, METH_NOARGS, nullptr}, {"_set_miopen_immediate", THPModule_setImmediateMiopen, METH_O, nullptr}, + {"_get_hipdnn_enabled", THPModule_userEnabledHipdnn, METH_NOARGS, nullptr}, + {"_set_hipdnn_enabled", THPModule_setUserEnabledHipdnn, METH_O, nullptr}, {"_get_cudnn_deterministic", THPModule_deterministicCuDNN, METH_NOARGS, @@ -2419,6 +2444,13 @@ PyObject* initModule() { #endif ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn)); +#if defined(USE_HIPDNN) + PyObject* has_hipdnn = Py_True; +#else + PyObject* has_hipdnn = Py_False; +#endif + ASSERT_TRUE(set_module_attr("_has_hipdnn", has_hipdnn)); + #if defined(USE_CUSPARSELT) || defined(USE_ROCM) PyObject* has_cusparselt = Py_True; #else diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 8c2f0fc64fef7..e7116f4eccc4e 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -24,6 +24,7 @@ TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE))) TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0) +TEST_HIPDNN = LazyVal(lambda: TEST_CUDA and 
torch.backends.hipdnn.is_available()) ROCM_VERSION = LazyVal(lambda : tuple(int(v) for v in torch.version.hip.split('.')[:2]) if torch.version.hip else (0, 0)) SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3)) From b16a986dc2e6bea930e27cabea38b1db3483c3dd Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 25 Mar 2026 15:00:31 -0500 Subject: [PATCH 02/10] [ROCm] Add hipdnn convolution support Implements forward and backward convolution (2D/3D) through the hipDNN frontend graph API, providing an alternative to the MIOpen backend on ROCm. - Graph-cached convolution: Forward (fprop), backward-data (dgrad), and backward-weight (wgrad) via hipDNN frontend graphs with a thread-local LRU cache (`ParamsLRUCache`) to amortize `graph->build()` cost - Dispatch integration: New `ConvBackend::Hipdnn` and `ConvBackend::HipdnnTranspose` variants wired through backend selection, memory format selection, forward/backward switches, and Python enum exposure. 
hipDNN takes priority over MIOpen when `torch.backends.hipdnn.enabled` is `True` Signed-off-by: zjgarvey Assisted-by: Claude Opus 4.6 --- aten/src/ATen/native/ConvUtils.h | 48 ++ aten/src/ATen/native/Convolution.cpp | 59 ++ aten/src/ATen/native/hipdnn/Conv_hipdnn.cpp | 714 ++++++++++++++++++ test/nn/test_convolution.py | 206 +++++ torch/csrc/Module.cpp | 4 +- torch/testing/_internal/common_device_type.py | 14 + 6 files changed, 1044 insertions(+), 1 deletion(-) create mode 100644 aten/src/ATen/native/hipdnn/Conv_hipdnn.cpp diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index 8c0771ddfdcc6..2c3a320825d4a 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -42,6 +42,23 @@ using miopen_depthwise_convolution_backward_fn = std::tuple); DECLARE_DISPATCH(miopen_depthwise_convolution_backward_fn, miopen_depthwise_convolution_backward_stub) +using hipdnn_convolution_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, int64_t, bool, bool, std::array); +DECLARE_DISPATCH(hipdnn_convolution_backward_fn, hipdnn_convolution_backward_stub) +using hipdnn_convolution_fn = at::Tensor(*)( + const at::Tensor&, const at::Tensor&, const std::optional&, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool); +DECLARE_DISPATCH(hipdnn_convolution_fn, hipdnn_convolution_stub) +using hipdnn_convolution_transpose_backward_fn = std::tuple(*)( + const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, + at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool, std::array); +DECLARE_DISPATCH(hipdnn_convolution_transpose_backward_fn, hipdnn_convolution_transpose_backward_stub) +using hipdnn_convolution_transpose_fn = at::Tensor(*)( + const at::Tensor&, const at::Tensor&, const std::optional&, + at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, int64_t, bool, bool); 
+DECLARE_DISPATCH(hipdnn_convolution_transpose_fn, hipdnn_convolution_transpose_stub) + using mkldnn_convolution_backward_fn = std::tuple(*)( const at::Tensor&, const at::Tensor&, const at::Tensor&, at::IntArrayRef, at::IntArrayRef, at::IntArrayRef, int64_t, std::array); @@ -117,6 +134,8 @@ enum class ConvBackend { Xnnpack2d, Mps, MpsTranspose, + Hipdnn, + HipdnnTranspose, }; // Overload for selecting the convolution backend from the full set of convolution inputs. @@ -398,6 +417,35 @@ inline at::MemoryFormat miopen_conv_suggest_memory_format(const at::Tensor& inpu return at::MemoryFormat::Contiguous; } +inline at::MemoryFormat hipdnn_conv_suggest_memory_format(const at::Tensor& input, const at::Tensor& weight) { + if (input.scalar_type() == at::kDouble || + weight.scalar_type() == at::kDouble) { + return at::MemoryFormat::Contiguous; + } + + auto input_memory_format = input.suggest_memory_format(); + auto weight_memory_format = weight.suggest_memory_format(); + auto weight_ndim = weight.ndimension(); + + bool can_use_channels_last_2d = (weight_ndim == 4) && ( + (input_memory_format == at::MemoryFormat::ChannelsLast) || + (weight_memory_format == at::MemoryFormat::ChannelsLast) + ); + if (can_use_channels_last_2d) { + return at::MemoryFormat::ChannelsLast; + } + + bool can_use_channels_last_3d = (weight_ndim == 5) && ( + (input_memory_format == at::MemoryFormat::ChannelsLast3d) || + (weight_memory_format == at::MemoryFormat::ChannelsLast3d) + ); + if (can_use_channels_last_3d) { + return at::MemoryFormat::ChannelsLast3d; + } + + return at::MemoryFormat::Contiguous; +} + // deprecated, but to remove would be BC-breaking inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) { return miopen_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous; diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 51d83ba16779e..a5297eab5da17 100644 --- 
a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -492,6 +492,15 @@ struct ConvParams { && !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1 ; } + bool use_hipdnn(const at::Tensor& input, const at::Tensor& weight) const { + if (!at::globalContext().userEnabledHipdnn()) return false; + if (!detail::getCUDAHooks().compiledWithHipDNN()) return false; + if (!input.is_cuda()) return false; + auto dtype = input.scalar_type(); + if (dtype != at::kFloat && dtype != at::kHalf && dtype != at::kBFloat16) return false; + if (input.dim() < 4 || input.dim() > 5) return false; + return true; + } bool use_mkldnn(const at::Tensor& input, const at::Tensor& weight) const { #if AT_MKLDNN_ENABLED() if (!at::globalContext().userEnabledMkldnn()) { @@ -623,6 +632,10 @@ DEFINE_DISPATCH(convolution_depthwise3x3_winograd_stub); DEFINE_DISPATCH(miopen_convolution_backward_stub); DEFINE_DISPATCH(miopen_convolution_transpose_backward_stub); DEFINE_DISPATCH(miopen_depthwise_convolution_backward_stub); +DEFINE_DISPATCH(hipdnn_convolution_backward_stub); +DEFINE_DISPATCH(hipdnn_convolution_transpose_backward_stub); +DEFINE_DISPATCH(hipdnn_convolution_stub); +DEFINE_DISPATCH(hipdnn_convolution_transpose_stub); DEFINE_DISPATCH(mkldnn_convolution_backward_stub); DEFINE_DISPATCH(mkldnn_convolution_transpose_stub); DEFINE_DISPATCH(mkldnn_convolution_transpose_backward_stub); @@ -636,6 +649,10 @@ REGISTER_NO_CPU_DISPATCH(cudnn_convolution_transpose_backward_stub) REGISTER_NO_CPU_DISPATCH(miopen_convolution_backward_stub) REGISTER_NO_CPU_DISPATCH(miopen_convolution_transpose_backward_stub) REGISTER_NO_CPU_DISPATCH(miopen_depthwise_convolution_backward_stub) +REGISTER_NO_CPU_DISPATCH(hipdnn_convolution_backward_stub) +REGISTER_NO_CPU_DISPATCH(hipdnn_convolution_transpose_backward_stub) +REGISTER_NO_CPU_DISPATCH(hipdnn_convolution_stub) +REGISTER_NO_CPU_DISPATCH(hipdnn_convolution_transpose_stub) template static 
std::ostream& operator<<(std::ostream & out, const ConvParams& params) { @@ -1275,6 +1292,8 @@ static ConvBackend _select_conv_backend( if (params.is_depthwise(input, weight)) { if (params.use_cudnn_depthwise(input, weight)) { return ConvBackend::Cudnn; + } else if (params.use_hipdnn(input, weight)) { + return ConvBackend::Hipdnn; } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) { return ConvBackend::MiopenDepthwise; } else { @@ -1292,6 +1311,12 @@ static ConvBackend _select_conv_backend( } else { return ConvBackend::Cudnn; } + } else if (params.use_hipdnn(input, weight)) { + if (params.transposed) { + return ConvBackend::HipdnnTranspose; + } else { + return ConvBackend::Hipdnn; + } } else if (params.use_miopen(input, weight, bias_sizes_opt.has_value())) { if (params.transposed) { return ConvBackend::MiopenTranspose; @@ -1489,6 +1514,10 @@ static inline at::MemoryFormat determine_backend_memory_format( backend_memory_format = miopen_conv_suggest_memory_format(input, weight); } break; + case ConvBackend::Hipdnn: + case ConvBackend::HipdnnTranspose: + backend_memory_format = hipdnn_conv_suggest_memory_format(input, weight); + break; case ConvBackend::Mkldnn: case ConvBackend::MkldnnTranspose: if (mkldnn_conv_use_channels_last(input, weight)) { @@ -1629,6 +1658,20 @@ at::Tensor _convolution( output = output.view(calc_output_size(input, weight, params)); break; } + case ConvBackend::Hipdnn: + check_input_same_type_as_parameters(input, weight, bias); + output = hipdnn_convolution_stub( + input.device().type(), + input, weight, bias, params.padding, params.stride, + params.dilation, params.groups, params.benchmark, params.deterministic); + break; + case ConvBackend::HipdnnTranspose: + check_input_same_type_as_parameters(input, weight, bias); + output = hipdnn_convolution_transpose_stub( + input.device().type(), + input, weight, bias, params.padding, params.output_padding, + params.stride, params.dilation, params.groups, params.benchmark, 
params.deterministic); + break; case ConvBackend::Miopen: check_input_same_type_as_parameters(input, weight, bias); output = at::miopen_convolution( @@ -2197,6 +2240,22 @@ std::tuple convolution_backward( TORCH_INTERNAL_ASSERT(false, "Mkldnn backend was selected in PyTorch compiled without mkldnn support"); #endif break; + case ConvBackend::Hipdnn: + check_input_same_type_as_parameters(input, weight); + std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = + hipdnn_convolution_backward_stub( + input.device().type(), + input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.stride, + params.dilation, params.groups, params.benchmark, params.deterministic, output_mask); + break; + case ConvBackend::HipdnnTranspose: + check_input_same_type_as_parameters(input, weight); + std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = + hipdnn_convolution_transpose_backward_stub( + input.device().type(), + input.contiguous(backend_memory_format), grad_output, weight, params.padding, params.output_padding, + params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, output_mask); + break; case ConvBackend::Miopen: check_input_same_type_as_parameters(input, weight); std::tie(backend_grad_input, backend_grad_weight, backend_grad_bias) = diff --git a/aten/src/ATen/native/hipdnn/Conv_hipdnn.cpp b/aten/src/ATen/native/hipdnn/Conv_hipdnn.cpp new file mode 100644 index 0000000000000..2ee74e52bf250 --- /dev/null +++ b/aten/src/ATen/native/hipdnn/Conv_hipdnn.cpp @@ -0,0 +1,714 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + +#include + +#if !AT_ROCM_ENABLED() || !defined(USE_HIPDNN) + +// No forward stubs needed: dispatch stubs (REGISTER_NO_CPU_DISPATCH) handle +// the non-CUDA case, and backend selection (use_hipdnn()) prevents dispatch +// to hipDNN when not compiled with hipDNN support. 
+ +#else // AT_ROCM_ENABLED && USE_HIPDNN + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at::native { + +// --------------------------------------------------------------------------- +// Cache key: captures everything that determines graph topology +// --------------------------------------------------------------------------- +constexpr int hipdnn_max_dim = 3; + +struct HipdnnConvParams { + c10::DeviceIndex device_id; + hipdnn_frontend::DataType dataType; + int input_size[2 + hipdnn_max_dim]; + uint8_t input_dim; + at::MemoryFormat memory_format; + int weight_size[2 + hipdnn_max_dim]; + int output_size[2 + hipdnn_max_dim]; // dgrad/wgrad: disambiguates output_padding + int padding[hipdnn_max_dim]; + int stride[hipdnn_max_dim]; + int dilation[hipdnn_max_dim]; + int64_t groups; + bool has_bias; + int operation; // 0=fprop, 1=dgrad, 2=wgrad +}; + +static void setHipdnnConvParams( + HipdnnConvParams* params, + const Tensor& input, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool has_bias, + at::MemoryFormat memory_format, + int operation, + IntArrayRef output_size = {}) { + memset(params, 0, sizeof(*params)); + params->device_id = input.device().index(); + params->dataType = getHipdnnDataType(input); + params->input_dim = static_cast(input.dim()); + params->memory_format = memory_format; + params->groups = groups; + params->has_bias = has_bias; + params->operation = operation; + for (int i = 0; i < input.dim(); i++) { + params->input_size[i] = static_cast(input.size(i)); + } + for (int i = 0; i < weight.dim(); i++) { + params->weight_size[i] = static_cast(weight.size(i)); + } + for (size_t i = 0; i < output_size.size(); i++) { + params->output_size[i] = static_cast(output_size[i]); + } + int spatial_dims = input.dim() - 2; + for (int i = 0; i < spatial_dims; i++) { + params->padding[i] = 
static_cast(padding[i]); + params->stride[i] = static_cast(stride[i]); + params->dilation[i] = static_cast(dilation[i]); + } +} + +// --------------------------------------------------------------------------- +// Cached graph value +// --------------------------------------------------------------------------- +struct HipdnnConvCachedGraph { + std::shared_ptr graph; + int64_t workspace_size; +}; + +// --------------------------------------------------------------------------- +// Thread-local LRU cache (same pattern as cuDNN v8 Conv_v8.cpp) +// --------------------------------------------------------------------------- +// Returns the LRU cache limit for convolution graphs. +// Special values (matching cuDNN v8): +// 0 = unlimited (no eviction) +// negative = caching disabled +static int getHipdnnConvCacheLimit() { + static int limit = []{ + constexpr int DEFAULT_LIMIT = 10000; + const auto val = c10::utils::get_env("TORCH_HIPDNN_CONV_LRU_CACHE_LIMIT"); + if (!val) { + return DEFAULT_LIMIT; + } + try { + return std::stoi(val.value()); + } catch (std::invalid_argument const&) { + TORCH_WARN( + "invalid TORCH_HIPDNN_CONV_LRU_CACHE_LIMIT,", + " using default LRU cache limit of ", + DEFAULT_LIMIT, + " entries."); + } catch (std::out_of_range const&) { + TORCH_WARN( + "invalid TORCH_HIPDNN_CONV_LRU_CACHE_LIMIT,", + " using default LRU cache limit of ", + DEFAULT_LIMIT, + " entries."); + } + return DEFAULT_LIMIT; + }(); + return limit; +} + +// LRU cache for hipDNN convolution graph lookups. Keyed by convolution +// parameters (POD struct), valued by compiled graph. When we add hipDNN +// batch-norm support, this can move to a shared header. 
+template +struct ParamsLRUCache { + using KeyWrapper = ParamsWrapper; + + int cache_limit; + std::list cache_order; + std::unordered_map< + KeyWrapper, + std::pair::iterator>, + ParamsWrapperHash> cache; + + explicit ParamsLRUCache(int limit) : cache_limit(limit) {} + + ValueType* find(const KeyType& key) { + if (cache_limit < 0) return nullptr; + KeyWrapper wrapped; + wrapped.pod = key; + auto it = cache.find(wrapped); + if (it == cache.end()) return nullptr; + if (cache_limit) { + cache_order.splice(cache_order.begin(), cache_order, it->second.second); + } + return &(it->second.first); + } + + void update(const KeyType& key, ValueType entry) { + if (cache_limit < 0) return; + KeyWrapper wrapped; + wrapped.pod = key; + auto it = cache.find(wrapped); + if (it == cache.end()) { + if (cache_limit == 0) { + cache.emplace(wrapped, std::make_pair(std::move(entry), cache_order.end())); + } else { + if (static_cast(cache.size()) >= cache_limit) { + auto count = cache.erase(cache_order.back()); + TORCH_INTERNAL_ASSERT(count == 1, "LRU cache eviction failed to erase key"); + cache_order.pop_back(); + } + cache_order.emplace_front(wrapped); + cache.emplace(wrapped, std::make_pair(std::move(entry), cache_order.begin())); + } + } else { + it->second.first = std::move(entry); + if (cache_limit) { + cache_order.splice(cache_order.begin(), cache_order, it->second.second); + } + } + } +}; + +using HipdnnConvCache = ParamsLRUCache; + +static HipdnnConvCache* getHipdnnConvCache() { + static thread_local auto* cache = new HipdnnConvCache(getHipdnnConvCacheLimit()); + return cache; +} + +// --------------------------------------------------------------------------- +// Deterministic UID assignment for graph tensors +// --------------------------------------------------------------------------- +enum HipdnnConvUid : int64_t { + // Forward (fprop) + UID_INPUT = 1, + UID_WEIGHT = 2, + UID_OUTPUT = 3, + UID_BIAS = 4, + // Backward data (dgrad) — aliases + UID_DGRAD_GRAD_OUTPUT = 
UID_INPUT, + UID_DGRAD_WEIGHT = UID_WEIGHT, + UID_DGRAD_GRAD_INPUT = UID_OUTPUT, + // Backward weight (wgrad) — aliases + UID_WGRAD_GRAD_OUTPUT = UID_INPUT, + UID_WGRAD_INPUT = UID_WEIGHT, + UID_WGRAD_GRAD_WEIGHT = UID_OUTPUT, +}; + +// --------------------------------------------------------------------------- +// Graph builders +// +// Note: groups are not explicitly passed to graph builders. HipDNN infers +// groupCount from tensor dimensions (input_channels / weight_channels_per_group). +// PyTorch provides correctly-shaped weight tensors [C_out, C_in/groups, kH, kW]. +// --------------------------------------------------------------------------- +static HipdnnConvCachedGraph buildConvFpropGraph( + hipdnnHandle_t handle, + const Tensor& input, + const Tensor& weight, + const Tensor& output, + const Tensor* bias, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation) { + + auto inputType = getHipdnnDataType(input); + auto graph = std::make_shared(); + graph->set_io_data_type(inputType) + .set_intermediate_data_type(hipdnn_frontend::DataType::FLOAT) + .set_compute_data_type(hipdnn_frontend::DataType::FLOAT); + + auto x_attr = createTensorAttributes(input); + x_attr->set_uid(UID_INPUT); + auto w_attr = createTensorAttributes(weight); + w_attr->set_uid(UID_WEIGHT); + + hipdnn_frontend::graph::ConvFpropAttributes conv_attrs; + conv_attrs.set_padding(std::vector(padding.begin(), padding.end())); + conv_attrs.set_stride(std::vector(stride.begin(), stride.end())); + conv_attrs.set_dilation(std::vector(dilation.begin(), dilation.end())); + + auto conv_out = graph->conv_fprop(x_attr, w_attr, conv_attrs); + + if (bias) { + conv_out->set_dim(output.sizes().vec()); + conv_out->set_stride(output.strides().vec()); + + auto bias_reshaped = reshape_bias(input.dim(), *bias); + auto b_attr = createTensorAttributes(bias_reshaped); + b_attr->set_uid(UID_BIAS); + + hipdnn_frontend::graph::PointwiseAttributes add_attrs; + 
add_attrs.set_mode(hipdnn_frontend::PointwiseMode::ADD); + add_attrs.set_compute_data_type(inputType); + + auto y_attr = graph->pointwise(conv_out, b_attr, add_attrs); + y_attr->set_output(true).set_uid(UID_OUTPUT); + } else { + conv_out->set_output(true).set_uid(UID_OUTPUT); + } + + HIPDNN_FE_CHECK(graph->build(handle)); + + int64_t ws = 0; + HIPDNN_FE_CHECK(graph->get_workspace_size(ws)); + + return {std::move(graph), ws}; +} + +static HipdnnConvCachedGraph buildConvDgradGraph( + hipdnnHandle_t handle, + const Tensor& grad_output, + const Tensor& weight, + const Tensor& output, + const Tensor* bias, + IntArrayRef input_size, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation) { + + auto inputType = getHipdnnDataType(grad_output); + auto graph = std::make_shared(); + graph->set_io_data_type(inputType) + .set_intermediate_data_type(hipdnn_frontend::DataType::FLOAT) + .set_compute_data_type(hipdnn_frontend::DataType::FLOAT); + + auto dy_attr = createTensorAttributes(grad_output); + dy_attr->set_uid(UID_INPUT); + auto w_attr = createTensorAttributes(weight); + w_attr->set_uid(UID_WEIGHT); + + hipdnn_frontend::graph::ConvDgradAttributes conv_attrs; + conv_attrs.set_padding(std::vector(padding.begin(), padding.end())); + conv_attrs.set_stride(std::vector(stride.begin(), stride.end())); + conv_attrs.set_dilation(std::vector(dilation.begin(), dilation.end())); + + auto dx_attr = graph->conv_dgrad(dy_attr, w_attr, conv_attrs); + dx_attr->set_dim(std::vector(input_size.begin(), input_size.end())); + + if (bias) { + dx_attr->set_stride(output.strides().vec()); + + auto bias_reshaped = reshape_bias(grad_output.dim(), *bias); + auto b_attr = createTensorAttributes(bias_reshaped); + b_attr->set_uid(UID_BIAS); + + hipdnn_frontend::graph::PointwiseAttributes add_attrs; + add_attrs.set_mode(hipdnn_frontend::PointwiseMode::ADD); + add_attrs.set_compute_data_type(inputType); + + auto y_attr = graph->pointwise(dx_attr, b_attr, add_attrs); + 
y_attr->set_output(true).set_uid(UID_OUTPUT); + } else { + dx_attr->set_output(true).set_uid(UID_OUTPUT); + } + + HIPDNN_FE_CHECK(graph->build(handle)); + + int64_t ws = 0; + HIPDNN_FE_CHECK(graph->get_workspace_size(ws)); + + return {std::move(graph), ws}; +} + +static HipdnnConvCachedGraph buildConvWgradGraph( + hipdnnHandle_t handle, + const Tensor& grad_output, + const Tensor& input, + IntArrayRef weight_size, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation) { + + auto inputType = getHipdnnDataType(input); + auto graph = std::make_shared(); + // No set_intermediate_data_type needed: single-op graph has no virtual tensors. + graph->set_io_data_type(inputType) + .set_compute_data_type(hipdnn_frontend::DataType::FLOAT); + + auto dy_attr = createTensorAttributes(grad_output); + dy_attr->set_uid(UID_INPUT); + auto x_attr = createTensorAttributes(input); + x_attr->set_uid(UID_WEIGHT); + + hipdnn_frontend::graph::ConvWgradAttributes conv_attrs; + conv_attrs.set_padding(std::vector(padding.begin(), padding.end())); + conv_attrs.set_stride(std::vector(stride.begin(), stride.end())); + conv_attrs.set_dilation(std::vector(dilation.begin(), dilation.end())); + + auto dw_attr = graph->conv_wgrad(dy_attr, x_attr, conv_attrs); + dw_attr->set_dim(std::vector(weight_size.begin(), weight_size.end())); + dw_attr->set_output(true).set_uid(UID_OUTPUT); + + HIPDNN_FE_CHECK(graph->build(handle)); + + int64_t ws = 0; + HIPDNN_FE_CHECK(graph->get_workspace_size(ws)); + + return {std::move(graph), ws}; +} + +// --------------------------------------------------------------------------- +// Graph execution helpers (cache-check-then-build-and-execute) +// --------------------------------------------------------------------------- +static void runHipdnnConvFprop( + const Tensor& input, + const Tensor& weight, + const Tensor& output, + const Tensor* bias, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + at::MemoryFormat 
memory_format, + bool benchmark, + bool deterministic) { + + TORCH_CHECK( + !deterministic, + "hipdnn_convolution does not support deterministic mode yet. " + "hipDNN does not currently provide engine-level determinism guarantees."); + if (benchmark) { + TORCH_WARN_ONCE( + "hipdnn_convolution: benchmark mode is not supported yet and will be ignored. " + "hipDNN does not currently support algorithm search."); + } + + auto handle = getHipdnnHandle(); + auto* cache = getHipdnnConvCache(); + + bool has_bias = bias != nullptr; + HipdnnConvParams key; + setHipdnnConvParams(&key, input, weight, padding, stride, dilation, + groups, has_bias, memory_format, /*operation=*/0); + + auto* cached = cache->find(key); + if (!cached) { + auto entry = buildConvFpropGraph( + handle, input, weight, output, bias, padding, stride, dilation); + cache->update(key, std::move(entry)); + cached = cache->find(key); + } + + std::unordered_map variantPack; + variantPack[UID_INPUT] = input.data_ptr(); + variantPack[UID_WEIGHT] = weight.data_ptr(); + variantPack[UID_OUTPUT] = output.data_ptr(); + if (bias) { + variantPack[UID_BIAS] = bias->data_ptr(); + } + + // Workspace inherits device from input.options() + auto workspace = at::empty({cached->workspace_size}, input.options().dtype(at::kByte)); + HIPDNN_FE_CHECK(cached->graph->execute(handle, variantPack, workspace.data_ptr())); +} + +static void runHipdnnConvDgrad( + const Tensor& grad_output, + const Tensor& weight, + const Tensor& grad_input, + const Tensor* bias, + IntArrayRef input_size, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + at::MemoryFormat memory_format, + bool benchmark, + bool deterministic) { + + TORCH_CHECK( + !deterministic, + "hipdnn_convolution does not support deterministic mode yet. " + "hipDNN does not currently provide engine-level determinism guarantees."); + if (benchmark) { + TORCH_WARN_ONCE( + "hipdnn_convolution: benchmark mode is not supported yet and will be ignored. 
" + "hipDNN does not currently support algorithm search."); + } + + auto handle = getHipdnnHandle(); + auto* cache = getHipdnnConvCache(); + + bool has_bias = bias != nullptr; + HipdnnConvParams key; + // For dgrad, use grad_output as the "input" for the cache key. + // input_size disambiguates cases with different output_padding. + setHipdnnConvParams(&key, grad_output, weight, padding, stride, dilation, + groups, has_bias, memory_format, /*operation=*/1, + input_size); + + auto* cached = cache->find(key); + if (!cached) { + auto entry = buildConvDgradGraph(handle, grad_output, weight, grad_input, + bias, input_size, padding, stride, dilation); + cache->update(key, std::move(entry)); + cached = cache->find(key); + } + + std::unordered_map variantPack; + variantPack[UID_DGRAD_GRAD_OUTPUT] = grad_output.data_ptr(); + variantPack[UID_DGRAD_WEIGHT] = weight.data_ptr(); + variantPack[UID_DGRAD_GRAD_INPUT] = grad_input.data_ptr(); + if (bias) { + variantPack[UID_BIAS] = bias->data_ptr(); + } + + // Workspace inherits device from grad_output.options() + auto workspace = at::empty({cached->workspace_size}, grad_output.options().dtype(at::kByte)); + HIPDNN_FE_CHECK(cached->graph->execute(handle, variantPack, workspace.data_ptr())); +} + +static void runHipdnnConvWgrad( + const Tensor& grad_output, + const Tensor& input, + const Tensor& grad_weight, + IntArrayRef weight_size, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + at::MemoryFormat memory_format, + bool benchmark, + bool deterministic) { + + TORCH_CHECK( + !deterministic, + "hipdnn_convolution does not support deterministic mode yet. " + "hipDNN does not currently provide engine-level determinism guarantees."); + if (benchmark) { + TORCH_WARN_ONCE( + "hipdnn_convolution: benchmark mode is not supported yet and will be ignored. 
" + "hipDNN does not currently support algorithm search."); + } + + auto handle = getHipdnnHandle(); + auto* cache = getHipdnnConvCache(); + + HipdnnConvParams key; + setHipdnnConvParams(&key, grad_output, input, padding, stride, dilation, + groups, /*has_bias=*/false, memory_format, /*operation=*/2, + weight_size); + + auto* cached = cache->find(key); + if (!cached) { + auto entry = buildConvWgradGraph(handle, grad_output, input, weight_size, + padding, stride, dilation); + cache->update(key, std::move(entry)); + cached = cache->find(key); + } + + std::unordered_map variantPack; + variantPack[UID_WGRAD_GRAD_OUTPUT] = grad_output.data_ptr(); + variantPack[UID_WGRAD_INPUT] = input.data_ptr(); + variantPack[UID_WGRAD_GRAD_WEIGHT] = grad_weight.data_ptr(); + + // Workspace inherits device from grad_output.options() + auto workspace = at::empty({cached->workspace_size}, grad_output.options().dtype(at::kByte)); + HIPDNN_FE_CHECK(cached->graph->execute(handle, variantPack, workspace.data_ptr())); +} + +// --------------------------------------------------------------------------- +// Public entry points +// --------------------------------------------------------------------------- +Tensor hipdnn_convolution( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { + + TensorArg input{input_t, "input", 1}; + TensorArg weight{weight_t, "weight", 2}; + CheckedFrom c = "hipdnn_convolution"; + checkAllSameType(c, {input, weight}); + checkAllSameGPU(c, {input, weight}); + + auto memory_format = hipdnn_conv_suggest_memory_format(input_t, weight_t); + auto input_c = input_t.contiguous(memory_format); + auto weight_c = weight_t.contiguous(memory_format); + + auto output_size = conv_output_size( + input_c.sizes(), weight_c.sizes(), padding, stride, dilation); + auto output = at::empty(output_size, input_c.options(), 
memory_format); + + bool has_bias = bias_opt.has_value() && bias_opt->defined(); + const Tensor* bias_ptr = has_bias ? &(*bias_opt) : nullptr; + runHipdnnConvFprop(input_c, weight_c, output, bias_ptr, + padding, stride, dilation, groups, memory_format, + benchmark, deterministic); + + return output; +} + +Tensor hipdnn_convolution_transpose( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_opt, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic) { + + TensorArg input{input_t, "input", 1}; + TensorArg weight{weight_t, "weight", 2}; + CheckedFrom c = "hipdnn_convolution_transpose"; + checkAllSameType(c, {input, weight}); + checkAllSameGPU(c, {input, weight}); + + auto memory_format = hipdnn_conv_suggest_memory_format(input_t, weight_t); + auto input_c = input_t.contiguous(memory_format); + auto weight_c = weight_t.contiguous(memory_format); + + auto trans_output_size = conv_input_size( + input_c.sizes(), weight_c.sizes(), padding, output_padding, stride, dilation, groups); + auto output = at::empty(trans_output_size, input_c.options(), memory_format); + + bool has_bias = bias_opt.has_value() && bias_opt->defined(); + const Tensor* bias_ptr = has_bias ? 
&(*bias_opt) : nullptr; + runHipdnnConvDgrad(input_c, weight_c, output, bias_ptr, + trans_output_size, padding, stride, dilation, + groups, memory_format, benchmark, deterministic); + + return output; +} + +// --------------------------------------------------------------------------- +// Backward +// --------------------------------------------------------------------------- +std::tuple hipdnn_convolution_backward( + const Tensor& input, + const Tensor& grad_output_t, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + std::array output_mask) { + + auto memory_format = hipdnn_conv_suggest_memory_format(input, weight); + auto grad_output = grad_output_t.contiguous(memory_format); + auto input_c = input.contiguous(memory_format); + auto weight_c = weight.contiguous(memory_format); + + Tensor grad_input, grad_weight, grad_bias; + + if (output_mask[0]) { + grad_input = at::empty(input_c.sizes(), input_c.options(), memory_format); + runHipdnnConvDgrad(grad_output, weight_c, grad_input, /*bias=*/nullptr, + input_c.sizes(), padding, stride, dilation, + groups, memory_format, benchmark, deterministic); + } + + if (output_mask[1]) { + grad_weight = at::empty(weight_c.sizes(), weight_c.options(), memory_format); + runHipdnnConvWgrad(grad_output, input_c, grad_weight, weight_c.sizes(), + padding, stride, dilation, groups, memory_format, + benchmark, deterministic); + } + + if (output_mask[2]) { + std::vector reduce_dims; + reduce_dims.push_back(0); + for (int64_t i = 2; i < grad_output.dim(); i++) { + reduce_dims.push_back(i); + } + grad_bias = grad_output.sum(reduce_dims); + } + + return std::make_tuple( + std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); +} + +std::tuple hipdnn_convolution_transpose_backward( + const Tensor& input, + const Tensor& grad_output_t, + const Tensor& weight, + IntArrayRef padding, + IntArrayRef output_padding, + IntArrayRef 
stride, + IntArrayRef dilation, + int64_t groups, + bool benchmark, + bool deterministic, + std::array output_mask) { + + auto memory_format = hipdnn_conv_suggest_memory_format(input, weight); + auto grad_output = grad_output_t.contiguous(memory_format); + auto input_c = input.contiguous(memory_format); + auto weight_c = weight.contiguous(memory_format); + + Tensor grad_input, grad_weight, grad_bias; + + if (output_mask[0]) { + // Transpose backward-input = fprop + grad_input = at::empty(input_c.sizes(), input_c.options(), memory_format); + runHipdnnConvFprop(grad_output, weight_c, grad_input, /*bias=*/nullptr, + padding, stride, dilation, groups, memory_format, + benchmark, deterministic); + } + + if (output_mask[1]) { + // Transpose backward-weight = wgrad + grad_weight = at::empty(weight_c.sizes(), weight_c.options(), memory_format); + runHipdnnConvWgrad(input_c, grad_output, grad_weight, weight_c.sizes(), + padding, stride, dilation, groups, memory_format, + benchmark, deterministic); + } + + if (output_mask[2]) { + std::vector reduce_dims; + reduce_dims.push_back(0); + for (int64_t i = 2; i < grad_output.dim(); i++) { + reduce_dims.push_back(i); + } + grad_bias = grad_output.sum(reduce_dims); + } + + return std::make_tuple( + std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); +} + +// --------------------------------------------------------------------------- +// Dispatch stub registration +// --------------------------------------------------------------------------- +REGISTER_CUDA_DISPATCH(hipdnn_convolution_stub, &hipdnn_convolution) +REGISTER_CUDA_DISPATCH(hipdnn_convolution_transpose_stub, &hipdnn_convolution_transpose) +REGISTER_CUDA_DISPATCH(hipdnn_convolution_backward_stub, &hipdnn_convolution_backward) +REGISTER_CUDA_DISPATCH(hipdnn_convolution_transpose_backward_stub, &hipdnn_convolution_transpose_backward) + +} // namespace at::native + +#endif diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 
bb393f63fc38f..bcf1b95bd26a4 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -39,6 +39,7 @@ def _get_cudnn_version(): skipCPUIfNoMkldnn, skipCUDAIfMiopen, skipCUDAIfNoCudnn, + skipCUDAIfNoHipdnn, skipCUDAIfNoMiopen, skipCUDAIfRocm, skipCUDAIfRocmHipBlasltVersionLessThan, @@ -52,6 +53,7 @@ def _get_cudnn_version(): ) from torch.testing._internal.common_nn import _test_module_empty_input, NNTestCase from torch.testing._internal.common_utils import ( + DeterministicGuard, download_file, dtype2prec_DONTUSE, gradcheck, @@ -4348,6 +4350,210 @@ def test_depthwise_conv_64bit_indexing(self, device): y = c.to(device=device)(x.to(device=device)) self.assertEqual(yref, y, atol=5e-3, rtol=1e-4) + # === hipDNN convolution tests === + + def _hipdnn_compare_conv( + self, + device, + x_shape, + w_shape, + bias, + stride, + padding, + dilation, + groups, + dtype, + transposed=False, + output_padding=0, + memory_format=torch.contiguous_format, + atol=1e-4, + rtol=1e-4, + ): + """Compares float32 cpu reference output to hipdnn (gpu) output.""" + + x_gpu = torch.randn(*x_shape, dtype=dtype, device=device) + if memory_format != torch.contiguous_format: + x_gpu = x_gpu.contiguous(memory_format=memory_format) + w_gpu = torch.randn(*w_shape, dtype=dtype, device=device) + b_gpu = ( + torch.randn( + w_shape[1] if transposed else w_shape[0], dtype=dtype, device=device + ) + if bias + else None + ) + + x_cpu = x_gpu.float().cpu().requires_grad_(True) + w_cpu = w_gpu.float().cpu().requires_grad_(True) + b_cpu = b_gpu.float().cpu().requires_grad_(True) if bias else None + + conv_fn = F.conv_transpose2d if transposed else F.conv2d + kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) + if transposed: + kwargs["output_padding"] = output_padding + + out_cpu = conv_fn(x_cpu, w_cpu, b_cpu, **kwargs) + out_cpu.sum().backward() + + x_gpu = x_gpu.detach().requires_grad_(True) + w_gpu = w_gpu.detach().requires_grad_(True) + b_gpu = 
b_gpu.detach().requires_grad_(True) if bias else None + + out_gpu = conv_fn(x_gpu, w_gpu, b_gpu, **kwargs) + out_gpu.sum().backward() + + self.assertEqual(out_cpu.float(), out_gpu.float().cpu(), atol=atol, rtol=rtol) + self.assertEqual( + x_cpu.grad.float(), x_gpu.grad.float().cpu(), atol=atol, rtol=rtol + ) + self.assertEqual( + w_cpu.grad.float(), w_gpu.grad.float().cpu(), atol=atol, rtol=rtol + ) + if bias: + self.assertEqual( + b_cpu.grad.float(), b_gpu.grad.float().cpu(), atol=atol, rtol=rtol + ) + + @onlyCUDA + @skipCUDAIfNoHipdnn + @dtypes(torch.float32, torch.float16, torch.bfloat16) + @parametrize_test("has_bias", [False, True]) + @parametrize_test("transposed", [False, True]) + @torch.backends.hipdnn.flags(enabled=True) + def test_conv2d_hipdnn(self, device, dtype, has_bias, transposed): + # Condition on rocm sdk version when hipdnn ships with conv+bias fusion support by default. + if has_bias: + self.skipTest( + "No default plugin for hipdnn supports conv + bias fusion yet." + ) + + if dtype == torch.float32: + C_in, C_out, rtol, atol = 16, 32, 1e-4, 1e-4 + elif dtype == torch.float16: + C_in, C_out, rtol, atol = 8, 16, 5e-2, 5e-2 + else: + C_in, C_out, rtol, atol = 4, 8, 1e-1, 1e-1 + + if transposed: + self._hipdnn_compare_conv( + device, + (2, C_out, 16, 16), + (C_out, C_in, 3, 3), + bias=has_bias, + stride=2, + padding=1, + dilation=1, + groups=1, + dtype=dtype, + transposed=True, + output_padding=1, + atol=atol, + rtol=rtol, + ) + else: + self._hipdnn_compare_conv( + device, + (2, C_in, 32, 32), + (C_out, C_in, 3, 3), + bias=has_bias, + stride=1, + padding=1, + dilation=1, + groups=1, + dtype=dtype, + atol=atol, + rtol=rtol, + ) + + @onlyCUDA + @skipCUDAIfNoHipdnn + @torch.backends.hipdnn.flags(enabled=True) + @parametrize_test( + "config", + [ + subtest( + ((2, 64, 32, 32), (128, 64, 3, 3), 2, 1, 1, 1), + name="stride2", + ), + subtest( + ((2, 64, 32, 32), (128, 64, 3, 3), 1, 2, 2, 1), + name="dilation2", + ), + subtest( + ((2, 128, 32, 32), (128, 
32, 3, 3), 1, 1, 1, 4), + name="grouped", + ), + subtest( + ((2, 128, 32, 32), (128, 1, 3, 3), 1, 1, 1, 128), + name="depthwise", + ), + subtest( + ((1, 3, 224, 224), (64, 3, 7, 7), 2, 3, 1, 1), + name="resnet_first", + ), + ], + ) + def test_conv2d_hipdnn_configs(self, device, config): + x_shape, w_shape, stride, padding, dilation, groups = config + self._hipdnn_compare_conv( + device, + x_shape, + w_shape, + bias=False, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + dtype=torch.float32, + ) + + @onlyCUDA + @skipCUDAIfNoHipdnn + @torch.backends.hipdnn.flags(enabled=True) + def test_conv2d_hipdnn_channels_last(self, device): + self._hipdnn_compare_conv( + device, + (2, 64, 32, 32), + (128, 64, 3, 3), + bias=False, + stride=1, + padding=1, + dilation=1, + groups=1, + dtype=torch.float32, + memory_format=torch.channels_last, + ) + + @onlyCUDA + @skipCUDAIfNoHipdnn + @torch.backends.hipdnn.flags(enabled=True) + def test_conv2d_hipdnn_deterministic_error(self, device): + x = torch.randn(2, 64, 32, 32, device=device) + w = torch.randn(128, 64, 3, 3, device=device) + with DeterministicGuard(True): + with self.assertRaisesRegex( + RuntimeError, "hipdnn_convolution does not support deterministic" + ): + F.conv2d(x, w, padding=1) + + @onlyCUDA + @skipCUDAIfNoHipdnn + def test_conv2d_hipdnn_backend_selection(self, device): + x = torch.randn(2, 64, 32, 32, device=device) + w = torch.randn(128, 64, 3, 3, device=device) + inputs = [x, w, None, (1,) * 2, (1,) * 2, (1,) * 2, False, (0,) * 2, 1] + + with torch.backends.hipdnn.flags(enabled=True): + backend = torch._C._select_conv_backend(*inputs) + self.assertEqual(backend, torch._C._ConvBackend.Hipdnn) + + # Transposed + w_t = torch.randn(64, 128, 3, 3, device=device) + inputs_t = [x, w_t, None, (1,) * 2, (0,) * 2, (1,) * 2, True, (0,) * 2, 1] + with torch.backends.hipdnn.flags(enabled=True): + backend_t = torch._C._select_conv_backend(*inputs_t) + self.assertEqual(backend_t, 
torch._C._ConvBackend.HipdnnTranspose) + instantiate_device_type_tests( TestConvolutionNNDeviceType, globals(), allow_mps=True, allow_xpu=True diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 4324bffc48941..3cdcae081b808 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2586,7 +2586,9 @@ Call this whenever a new thread is created in order to propagate values from "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise) .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d) .value("Mps", at::native::ConvBackend::Mps) - .value("MpsTranspose,", at::native::ConvBackend::MpsTranspose); + .value("MpsTranspose,", at::native::ConvBackend::MpsTranspose) + .value("Hipdnn", at::native::ConvBackend::Hipdnn) + .value("HipdnnTranspose", at::native::ConvBackend::HipdnnTranspose); py_module.def( "_select_conv_backend", diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 8fc9f2cd9682f..33e65097642a0 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -550,6 +550,7 @@ class CUDATestBase(DeviceTypeTestBase): cudnn_version: ClassVar[Any] no_magma: ClassVar[bool] no_cudnn: ClassVar[bool] + no_hipdnn: ClassVar[bool] def has_cudnn(self): return not self.no_cudnn @@ -582,6 +583,9 @@ def setUpClass(cls): cls.no_cudnn = not torch.backends.cudnn.is_acceptable(t) cls.cudnn_version = None if cls.no_cudnn else torch.backends.cudnn.version() + # Determines if hipDNN is available + cls.no_hipdnn = not torch.backends.hipdnn.is_available() + # Acquires the current device as the primary (test) device cls.primary_device = f"cuda:{torch.cuda.current_device()}" @@ -2042,6 +2046,16 @@ def skipCUDAIfNoMiopen(fn): ) +def skipCUDAIfNoHipdnn(fn): + @wraps(fn) + def wrap_fn(self, *args, **kwargs): + if self.device_type == "cuda" and self.no_hipdnn: + raise unittest.SkipTest("hipDNN not available") + return fn(self, *args, **kwargs) + + 
return wrap_fn + + def skipLazy(fn): return skipLazyIf(True, "test doesn't work with lazy tensors")(fn) From f12ae4df094b6e3299eca57ca634df330db7763c Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Wed, 25 Mar 2026 13:35:04 -0700 Subject: [PATCH 03/10] [ROCm] Add hipdnn batchnorm support Adds support for selecting hipdnn as a backend for `batch_norm`. Unfortunately, batch norm is in a half-migrated state wrt. new dispatch stack. Consequently, this PR adds a new backend-specific frontend op. RFC: https://dev-discuss.pytorch.org/t/rfc-adding-a-batch-norm-backend-revisiting-dispatch-stack-issues/3327 Co-authored-by: Dmitry Nikolaev Signed-off-by: zjgarvey Assisted-by: Claude Opus 4.6 --- aten/src/ATen/functorch/BatchRulesNorm.cpp | 71 ++++ aten/src/ATen/native/Normalization.cpp | 33 ++ aten/src/ATen/native/Normalization.h | 1 + aten/src/ATen/native/cuda/Normalization.cu | 15 + .../ATen/native/hipdnn/BatchNorm_hipdnn.cpp | 396 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 10 + ...asDecompTest.test_has_decomposition.expect | 2 + .../check_forward_backward_compatibility.py | 2 + test/test_nn.py | 45 +- tools/autograd/derivatives.yaml | 10 + torch/_decomp/__init__.py | 1 + torch/_decomp/decompositions.py | 26 ++ torch/_decomp/decompositions_for_jvp.py | 1 + torch/csrc/Module.cpp | 3 +- 14 files changed, 606 insertions(+), 10 deletions(-) create mode 100644 aten/src/ATen/native/hipdnn/BatchNorm_hipdnn.cpp diff --git a/aten/src/ATen/functorch/BatchRulesNorm.cpp b/aten/src/ATen/functorch/BatchRulesNorm.cpp index affb6ce369f2b..d8bcf4fddd23e 100644 --- a/aten/src/ATen/functorch/BatchRulesNorm.cpp +++ b/aten/src/ATen/functorch/BatchRulesNorm.cpp @@ -815,6 +815,60 @@ struct MiopenBatchNormBackwardBatchRuleHelper { decltype(&fn),\ &fn>::apply) +template +struct HipdnnBatchNormBatchRuleHelper { + static std::tuple,Tensor, std::optional,Tensor, std::optional> apply( + const Tensor& input, std::optional input_bdim, + const Tensor& weight_opt, std::optional 
weight_bdim, + const std::optional& bias_opt, std::optional bias_bdim, + const std::optional& running_mean_opt, std::optional running_mean_bdim, + const std::optional& running_var_opt, std::optional running_var_bdim, + bool training, double momentum, double eps) { + return batch_norm_batch_rule( + input, input_bdim, weight_opt, weight_bdim, bias_opt, bias_bdim, + running_mean_opt, running_mean_bdim, running_var_opt, running_var_bdim, training, momentum, eps); + } +}; + +template +struct HipdnnBatchNormBackwardBatchRuleHelper { + static std::tuple apply( + const at::Tensor & input, + const at::Tensor & grad_out, + const at::Tensor & weight, + const std::optional & running_mean_opt, + const std::optional & running_var_opt, + const std::optional & save_mean_opt, + const std::optional & save_rstd_opt, + double eps) { + + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "HipdnnBatchNormBackwardBatchRuleHelper.apply"); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + int64_t cur_level = maybe_layer->layerId(); + + if (!areAnyBatchedAtLevel({input, grad_out, weight, running_mean_opt, + running_var_opt, save_mean_opt, save_rstd_opt}, cur_level)) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + return at::hipdnn_batch_norm_backward(input, grad_out, weight, + running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps); + } + + return batch_norm_backward_plumbing( + grad_out, input, weight, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, true, eps, {true, true, true}); + } +}; + +#define HIPDNN_BATCH_NORM_BATCH_RULE(fn) SINGLE_ARG(\ + HipdnnBatchNormBatchRuleHelper<\ + decltype(&ATEN_FN(fn)),\ + &ATEN_FN(fn)>::apply) + +#define HIPDNN_BATCH_NORM_BACKWARD_BATCH_RULE(fn) SINGLE_ARG(\ + HipdnnBatchNormBackwardBatchRuleHelper<\ + decltype(&fn),\ + &fn>::apply) + static std::tuple cudnn_batch_norm_backward_wrapper( const at::Tensor & grad_out, const at::Tensor & input, @@ -846,6 
+900,21 @@ static std::tuple miopen_batch_norm_backward_w return at::miopen_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps); } +static std::tuple hipdnn_batch_norm_backward_wrapper( + const at::Tensor & grad_out, + const at::Tensor & input, + const at::Tensor& weight_opt, + const std::optional & running_mean_opt, + const std::optional & running_var_opt, + const std::optional & save_mean_opt, + const std::optional & save_rstd_opt, + bool training, + double eps, + std::array output_mask) { + TORCH_INTERNAL_ASSERT(!training); + return at::hipdnn_batch_norm_backward(input, grad_out, weight_opt, running_mean_opt, running_var_opt, save_mean_opt, save_rstd_opt, eps); + } + // NB: This is NOT good. In the ideal world, we do NOT want to convert the new legit op back into native_batch_norm // as native_batch_norm has a problematic schema--it promises it is functional when it is not. However, vmap doesn't // work with dynamo anyway so we gain some buffer room to do wrong things here. 
The (reasonable) hope is that we will @@ -866,11 +935,13 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(native_batch_norm, NATIVE_BATCH_NORM_BATCH_RULE(native_batch_norm)); VMAP_SUPPORT(cudnn_batch_norm, CUDNN_BATCH_NORM_BATCH_RULE(cudnn_batch_norm)); VMAP_SUPPORT(miopen_batch_norm, MIOPEN_BATCH_NORM_BATCH_RULE(miopen_batch_norm)); + VMAP_SUPPORT(hipdnn_batch_norm, HIPDNN_BATCH_NORM_BATCH_RULE(hipdnn_batch_norm)); m.impl("_native_batch_norm_legit", _native_batch_norm_legit_batch); m.impl("_native_batch_norm_legit.no_stats", _native_batch_norm_legit_no_stats_batch); m.impl("native_batch_norm_backward", NATIVE_BATCH_NORM_BACKWARD_BATCH_RULE(native_batch_norm_backward)); m.impl("cudnn_batch_norm_backward", CUDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::cudnn_batch_norm_backward_wrapper)); m.impl("miopen_batch_norm_backward", MIOPEN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::miopen_batch_norm_backward_wrapper)); + m.impl("hipdnn_batch_norm_backward", HIPDNN_BATCH_NORM_BACKWARD_BATCH_RULE(at::functorch::hipdnn_batch_norm_backward_wrapper)); m.impl("native_group_norm", native_group_norm_plumbing); m.impl("native_group_norm_backward", native_group_norm_backward_plumbing); VMAP_SUPPORT(native_layer_norm, native_layer_norm_batch_rule); diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 71dab6aed8955..47018283c8874 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -47,6 +47,8 @@ #include #include #include +#include +#include #include #include #include @@ -520,6 +522,22 @@ BatchNormBackend _select_batch_norm_backend( return BatchNormBackend::Cudnn; } + // HipDNN — independent of MIOpen + if (at::globalContext().userEnabledHipdnn() + && detail::getCUDAHooks().compiledWithHipDNN() + && input.is_cuda() + && input.dim() >= 3 + && input.dim() <= 5 + && input.scalar_type() != at::kDouble + && weight.scalar_type() == at::kFloat + && weight.defined() && 
bias.defined() + && ((running_mean.defined() && running_var.defined()) + || (!running_mean.defined() && !running_var.defined() && training)) + && input.is_contiguous(input.suggest_memory_format()) + ) { + return BatchNormBackend::Hipdnn; + } + // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen // See https://github.com/pytorch/pytorch/issues/64427. // non static variable is used to be able to change environment variable in runtime for testing @@ -634,6 +652,19 @@ std::tuple _batch_norm_impl_index( std::make_tuple(2)); } + if (backend == BatchNormBackend::Hipdnn) { + return std::tuple_cat( + at::hipdnn_batch_norm( + input.contiguous(input.suggest_memory_format()), + weight.contiguous(), + bias.contiguous(), + running_mean, + running_var, + training, momentum, eps), + std::tuple(reserve), + std::make_tuple(3)); + } + return std::tuple_cat( at::native_batch_norm( input, weight, bias, running_mean, running_var, training, momentum, eps), @@ -683,6 +714,8 @@ std::tuple _batch_norm_impl_index_backward( return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace); } else if (impl_index == 2) { return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon); + } else if (impl_index == 3) { + return at::hipdnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon); } TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index); } diff --git a/aten/src/ATen/native/Normalization.h b/aten/src/ATen/native/Normalization.h index 5eebb514a4690..edf1d0c1411f7 100644 --- a/aten/src/ATen/native/Normalization.h +++ b/aten/src/ATen/native/Normalization.h @@ -12,6 +12,7 @@ enum class BatchNormBackend { Native, Cudnn, Miopen, + Hipdnn, }; TORCH_API BatchNormBackend 
_select_batch_norm_backend(const Tensor& input, const Tensor& weight, const Tensor& bias, const Tensor& running_mean, const Tensor& running_var, bool training, double eps); diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 55d848610f5d1..05f7a22dc8fae 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include #include @@ -499,6 +501,12 @@ std::tuple _batch_norm_with_update_cuda( reserve = at::empty({0}, input.options().dtype(kByte)); std::tie(output, save_mean, save_var) = at::miopen_batch_norm(input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps); + } else if (backend == BatchNormBackend::Hipdnn) { + reserve = at::empty({0}, input.options().dtype(kByte)); + std::tie(output, save_mean, save_var) = + at::hipdnn_batch_norm( + input, weight, bias, running_mean, running_var, + /*training*/true, momentum, eps); } else { reserve = at::empty({0}, input.options().dtype(kByte)); std::tie(output, save_mean, save_var) = @@ -523,6 +531,11 @@ std::tuple _batch_norm_with_update_cuda_out( } else if (backend == BatchNormBackend::Miopen) { std::tie(out, save_mean, save_var) = at::miopen_batch_norm_out(out, save_mean, save_var, input, weight, bias, running_mean, running_var, /*training*/true, momentum, eps); + } else if (backend == BatchNormBackend::Hipdnn) { + std::tie(out, save_mean, save_var) = + at::hipdnn_batch_norm_out(out, save_mean, save_var, + input, weight, bias, running_mean, running_var, + /*training*/true, momentum, eps); } else { std::tie(out, save_mean, save_var) = batch_norm_cuda_out(input, weight_opt, bias_opt, running_mean, running_var, /*update*/true, momentum, eps, out, save_mean, save_var); @@ -563,6 +576,8 @@ std::tuple _new_batch_norm_backward_cuda( return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, 
save_var, eps, reserve); } else if (backend == BatchNormBackend::Miopen) { return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps); + } else if (backend == BatchNormBackend::Hipdnn) { + return at::hipdnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var, eps); } else { return batch_norm_backward_cuda(grad_output, input, weight, running_mean, running_var, save_mean, save_var, update, eps, grad_input_mask); } diff --git a/aten/src/ATen/native/hipdnn/BatchNorm_hipdnn.cpp b/aten/src/ATen/native/hipdnn/BatchNorm_hipdnn.cpp new file mode 100644 index 0000000000000..d9ca442943332 --- /dev/null +++ b/aten/src/ATen/native/hipdnn/BatchNorm_hipdnn.cpp @@ -0,0 +1,396 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +// TODO: Remove the condition on AT_ROCM_ENABLED entirely, +// don't build this file as part of CPU build. 
+#include + +#if !AT_ROCM_ENABLED() + +namespace at::native { + +// See Note [ATen preprocessor philosophy] + +std::tuple hipdnn_batch_norm( + const Tensor& input, + const Tensor& weight, + const std::optional& bias_opt, + const std::optional& running_mean_opt, + const std::optional& running_var_opt, + bool training, + double exponential_average_factor, + double epsilon) { + TORCH_CHECK(false, "hipdnn_batch_norm: ATen not compiled with ROCM support"); +} + +std::tuple hipdnn_batch_norm_backward( + const Tensor& input, + const Tensor& grad_output, + const Tensor& weight, + const std::optional& running_mean_opt, + const std::optional& running_var_opt, + const std::optional& save_mean_opt, + const std::optional& save_var_opt, + double epsilon) { + TORCH_CHECK( + false, "hipdnn_batch_norm_backward: ATen not compiled with ROCM support"); +} + +} // namespace at::native + +#elif !defined(USE_HIPDNN) // AT_ROCM_ENABLED but no hipDNN + +namespace at::native { + +std::tuple hipdnn_batch_norm( + const Tensor& input, + const Tensor& weight, + const std::optional& bias_opt, + const std::optional& running_mean_opt, + const std::optional& running_var_opt, + bool training, + double exponential_average_factor, + double epsilon) { + TORCH_CHECK(false, "hipdnn_batch_norm: not compiled with hipDNN support"); +} + +std::tuple hipdnn_batch_norm_backward( + const Tensor& input, + const Tensor& grad_output, + const Tensor& weight, + const std::optional& running_mean_opt, + const std::optional& running_var_opt, + const std::optional& save_mean_opt, + const std::optional& save_var_opt, + double epsilon) { + TORCH_CHECK( + false, "hipdnn_batch_norm_backward: not compiled with hipDNN support"); +} + +} // namespace at::native + +#else // AT_ROCM_ENABLED && USE_HIPDNN + +#include +#include +#include +#include +#include + +#include + +namespace at::native { + +namespace { + +Tensor expandScale(const Tensor& t, int64_t dim) { + std::vector size{1, t.numel()}; + while (static_cast(size.size()) < 
dim) { + size.emplace_back(1); + } + return t.view(size); +} + +} // namespace + +std::tuple hipdnn_batch_norm( + const Tensor& input_t, + const Tensor& weight_t, + const std::optional& bias_t_opt, + const std::optional& running_mean_t_opt, + const std::optional& running_var_t_opt, + bool training, + double exponential_average_factor, + double epsilon) { + // See [Note: hacky wrapper removal for optional tensor] + c10::MaybeOwned bias_t_maybe_owned = + at::borrow_from_optional_tensor(bias_t_opt); + const Tensor& bias_t = *bias_t_maybe_owned; + const Tensor& running_mean_t = running_mean_t_opt.value_or(Tensor()); + const Tensor& running_var_t = running_var_t_opt.value_or(Tensor()); + + TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2}, + bias{bias_t, "bias", 3}, running_mean{running_mean_t, "running_mean", 4}, + running_var{running_var_t, "running_var", 5}; + CheckedFrom c = "hipdnn_batch_norm"; + + checkAllDefined(c, {input, weight, bias}); + if (!training) { + checkAllDefined(c, {running_mean, running_var}); + } + checkAllSameGPU(c, {input, weight, bias, running_mean, running_var}); + if (input->scalar_type() == ScalarType::Half || + input->scalar_type() == ScalarType::BFloat16) { + checkScalarType(c, weight, ScalarType::Float); + } else { + checkAllSameType(c, {input, weight}); + } + checkAllSameType(c, {weight, bias, running_mean, running_var}); + checkAllContiguous(c, {weight, bias, running_mean, running_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + checkDimRange(c, input, 2, 6 /* exclusive */); + auto num_features = input->size(1); + for (auto t : {weight, bias, running_mean, running_var}) { + if (t->defined()) { + checkNumel(c, t, num_features); + } + } + + auto output_t = at::empty_like( + input_t, input_t.options(), input_t.suggest_memory_format()); + TensorArg output{output_t, "output", 0}; + + auto handle = getHipdnnHandle(); + auto inputType = getHipdnnDataType(*input); + auto intermediateType = 
getHipdnnDataType(*weight); + Tensor save_mean, save_var; + + if (training) { + save_mean = at::empty({num_features}, weight_t.options()); + save_var = at::empty({num_features}, weight_t.options()); + + auto graph = std::make_shared(); + graph->set_io_data_type(inputType) + .set_intermediate_data_type(intermediateType) + .set_compute_data_type(hipdnn_frontend::DataType::FLOAT); + + auto input_attr = createTensorAttributes(*input); + auto weight_attr = + createTensorAttributes(expandScale(*weight, input->dim())); + auto bias_attr = createTensorAttributes(expandScale(*bias, input->dim())); + + auto bnAttributes = hipdnn_frontend::graph::BatchnormAttributes(); + auto epsilon_attr = + std::make_shared(); + epsilon_attr->set_value(epsilon); + bnAttributes.set_epsilon(epsilon_attr); + + std::shared_ptr prev_mean_attr; + std::shared_ptr prev_var_attr; + + if (running_mean->defined()) { + prev_mean_attr = + createTensorAttributes(expandScale(*running_mean, input->dim())); + prev_var_attr = + createTensorAttributes(expandScale(*running_var, input->dim())); + auto momentum_attr = + std::make_shared(); + momentum_attr->set_value(exponential_average_factor); + bnAttributes.set_previous_running_stats( + prev_mean_attr, prev_var_attr, momentum_attr); + } + + auto [y, savedMean, savedInvVar, nextMean, nextVar] = + graph->batchnorm(input_attr, weight_attr, bias_attr, bnAttributes); + y->set_output(true); + savedMean->set_output(true).set_data_type(intermediateType); + savedInvVar->set_output(true).set_data_type(intermediateType); + if (running_mean->defined()) { + nextMean->set_output(true).set_data_type(intermediateType); + nextVar->set_output(true).set_data_type(intermediateType); + } + + HIPDNN_FE_CHECK(graph->build(handle)); + + int64_t workspace_size = 0; + HIPDNN_FE_CHECK(graph->get_workspace_size(workspace_size)); + auto workspace = + at::empty({workspace_size}, input_t.options().dtype(at::kByte)); + + std::unordered_map variantPack; + variantPack[input_attr->get_uid()] = + 
const_cast(input->const_data_ptr()); + variantPack[weight_attr->get_uid()] = + const_cast(weight->const_data_ptr()); + variantPack[bias_attr->get_uid()] = + const_cast(bias->const_data_ptr()); + variantPack[y->get_uid()] = output->data_ptr(); + variantPack[savedMean->get_uid()] = save_mean.data_ptr(); + variantPack[savedInvVar->get_uid()] = save_var.data_ptr(); + if (running_mean->defined()) { + // running stats are updated in-place: prev and next point to the same + // tensor so the graph overwrites the running stats during execution. + variantPack[prev_mean_attr->get_uid()] = running_mean->data_ptr(); + variantPack[prev_var_attr->get_uid()] = running_var->data_ptr(); + variantPack[nextMean->get_uid()] = running_mean->data_ptr(); + variantPack[nextVar->get_uid()] = running_var->data_ptr(); + } + + HIPDNN_FE_CHECK(graph->execute(handle, variantPack, workspace.data_ptr())); + + } else { + save_mean = at::empty({0}, weight_t.options()); + save_var = at::empty({0}, weight_t.options()); + + auto graph = std::make_shared(); + graph->set_io_data_type(inputType) + .set_intermediate_data_type(intermediateType) + .set_compute_data_type(hipdnn_frontend::DataType::FLOAT); + + auto input_attr = createTensorAttributes(*input); + auto weight_attr = + createTensorAttributes(expandScale(*weight, input->dim())); + auto bias_attr = createTensorAttributes(expandScale(*bias, input->dim())); + auto mean_attr = + createTensorAttributes(expandScale(*running_mean, input->dim())); + auto variance_attr = + createTensorAttributes(expandScale(*running_var, input->dim())); + auto epsilon_attr = + std::make_shared(); + epsilon_attr->set_value(epsilon); + + auto bnAttributes = + hipdnn_frontend::graph::BatchnormInferenceAttributesVarianceExt(); + auto output_attr = graph->batchnorm_inference_variance_ext( + input_attr, + mean_attr, + variance_attr, + weight_attr, + bias_attr, + epsilon_attr, + bnAttributes); + output_attr->set_output(true); + + HIPDNN_FE_CHECK(graph->build(handle)); + + int64_t 
workspace_size = 0; + HIPDNN_FE_CHECK(graph->get_workspace_size(workspace_size)); + auto workspace = + at::empty({workspace_size}, input_t.options().dtype(at::kByte)); + + std::unordered_map variantPack; + variantPack[input_attr->get_uid()] = + const_cast(input->const_data_ptr()); + variantPack[weight_attr->get_uid()] = + const_cast(weight->const_data_ptr()); + variantPack[bias_attr->get_uid()] = + const_cast(bias->const_data_ptr()); + variantPack[mean_attr->get_uid()] = + const_cast(running_mean->const_data_ptr()); + variantPack[variance_attr->get_uid()] = + const_cast(running_var->const_data_ptr()); + variantPack[output_attr->get_uid()] = output->data_ptr(); + + HIPDNN_FE_CHECK(graph->execute(handle, variantPack, workspace.data_ptr())); + } + + return std::tuple{output_t, save_mean, save_var}; +} + +std::tuple hipdnn_batch_norm_backward( + const Tensor& input_t, + const Tensor& grad_output_t, + const Tensor& weight_t, + // Unused: but we require them to be passed so that double backwards + // has access + const std::optional& running_mean_opt, + const std::optional& running_var_opt, + const std::optional& save_mean_t_opt, + const std::optional& save_var_t_opt, + double epsilon) { + // See [Note: hacky wrapper removal for optional tensor] + const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor()); + const Tensor& save_var_t = save_var_t_opt.value_or(Tensor()); + + auto grad_output_contig = + grad_output_t.contiguous(input_t.suggest_memory_format()); + TensorArg input{input_t, "input", 1}, + grad_output{grad_output_contig, "grad_output", 2}, + weight{weight_t, "weight", 3}, save_mean{save_mean_t, "save_mean", 4}, + save_var{save_var_t, "save_var", 5}; + CheckedFrom c = "hipdnn_batch_norm_backward"; + + checkAllDefined(c, {input, grad_output, weight, save_mean, save_var}); + checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var}); + if (input->scalar_type() == ScalarType::Half || + input->scalar_type() == ScalarType::BFloat16) { + 
checkScalarType(c, weight, ScalarType::Float); + } else { + checkAllSameType(c, {input, weight}); + } + checkAllSameType(c, {input, grad_output}); + checkAllSameType(c, {weight, save_mean, save_var}); + checkAllContiguous(c, {save_mean, save_var}); + TORCH_CHECK(input->is_contiguous(input->suggest_memory_format())); + TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format())); + checkDimRange(c, input, 2, 6 /* exclusive */); + checkSameSize(c, input, grad_output); + auto num_features = input->size(1); + for (auto t : {weight, save_mean, save_var}) { + checkNumel(c, t, num_features); + } + + auto grad_input_t = at::empty( + input->sizes(), input->options(), input->suggest_memory_format()); + auto grad_weight_t = at::empty(weight->sizes(), weight->options()); + auto grad_bias_t = at::empty(weight->sizes(), weight->options()); + + auto handle = getHipdnnHandle(); + auto inputType = getHipdnnDataType(*input); + auto intermediateType = getHipdnnDataType(*weight); + + auto graph = std::make_shared(); + graph->set_io_data_type(inputType) + .set_intermediate_data_type(intermediateType) + .set_compute_data_type(hipdnn_frontend::DataType::FLOAT); + + auto dy_attr = createTensorAttributes(*grad_output); + auto input_attr = createTensorAttributes(*input); + auto weight_attr = createTensorAttributes(expandScale(*weight, input->dim())); + auto savedMeanAttr = + createTensorAttributes(expandScale(*save_mean, input->dim())); + auto savedInvVarAttr = + createTensorAttributes(expandScale(*save_var, input->dim())); + + auto bnBwdAttributes = hipdnn_frontend::graph::BatchnormBackwardAttributes(); + bnBwdAttributes.set_saved_mean_and_inv_variance( + savedMeanAttr, savedInvVarAttr); + + auto [dx, dscale, dbias] = graph->batchnorm_backward( + dy_attr, input_attr, weight_attr, bnBwdAttributes); + dx->set_output(true); + dscale->set_output(true).set_data_type(intermediateType); + dbias->set_output(true).set_data_type(intermediateType); + + 
HIPDNN_FE_CHECK(graph->build(handle)); + + int64_t workspace_size = 0; + HIPDNN_FE_CHECK(graph->get_workspace_size(workspace_size)); + auto workspace = + at::empty({workspace_size}, input_t.options().dtype(at::kByte)); + + std::unordered_map variantPack; + variantPack[dy_attr->get_uid()] = + const_cast(grad_output->const_data_ptr()); + variantPack[input_attr->get_uid()] = + const_cast(input->const_data_ptr()); + variantPack[weight_attr->get_uid()] = + const_cast(weight->const_data_ptr()); + variantPack[savedMeanAttr->get_uid()] = + const_cast(save_mean->const_data_ptr()); + variantPack[savedInvVarAttr->get_uid()] = + const_cast(save_var->const_data_ptr()); + variantPack[dx->get_uid()] = grad_input_t.data_ptr(); + variantPack[dscale->get_uid()] = grad_weight_t.data_ptr(); + variantPack[dbias->get_uid()] = grad_bias_t.data_ptr(); + + HIPDNN_FE_CHECK(graph->execute(handle, variantPack, workspace.data_ptr())); + + return std::tuple{ + grad_input_t, grad_weight_t, grad_bias_t}; +} + +} // namespace at::native + +#endif diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9e39eda0368fe..9cc648b5526b9 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4211,6 +4211,16 @@ CUDA: miopen_batch_norm_backward autogen: miopen_batch_norm_backward.out +- func: hipdnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: hipdnn_batch_norm + autogen: hipdnn_batch_norm.out + +- func: hipdnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, float epsilon) -> (Tensor, Tensor, Tensor) + dispatch: + CUDA: hipdnn_batch_norm_backward + autogen: hipdnn_batch_norm_backward.out + - func: miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor dispatch: CUDA: miopen_convolution diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 61882630021d8..1215a6df576fe 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -876,6 +876,8 @@ aten::hardshrink_backward aten::hardshrink_backward.grad_input aten::hash_tensor aten::hash_tensor.out +aten::hipdnn_batch_norm +aten::hipdnn_batch_norm.out aten::histc aten::histc.out aten::histogram.bin_ct diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 41699c7b86ee6..a0fdb06bab1bd 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -91,6 +91,8 @@ ("aten::solve.solution", datetime.date(9999, 1, 1)), ("aten::_solve_helper", datetime.date(9999, 1, 1)), ("aten::_convolution_nogroup", datetime.date(9999, 1, 1)), + ("aten::hipdnn_batch_norm", datetime.date(9999, 1, 1)), + ("aten::hipdnn_batch_norm_backward", datetime.date(9999, 1, 1)), ("aten::miopen_convolution_backward", datetime.date(9999, 1, 1)), ("aten::miopen_convolution_backward_bias", datetime.date(9999, 1, 1)), ("aten::miopen_convolution_backward_input", datetime.date(9999, 1, 1)), diff --git a/test/test_nn.py b/test/test_nn.py index 3a5258d8ff9f6..7f9aa4f4fb5dd 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -41,7 +41,7 @@ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ 
skipIfTorchDynamo, gcIfJetson, set_default_dtype from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ - _get_torch_rocm_version + TEST_HIPDNN, _get_torch_rocm_version from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \ module_tests, criterion_tests, loss_reference_fns, _create_basic_net, \ ctcloss_reference, get_new_module_tests, single_batch_reference_fn, _test_bfloat16_ops, _test_module_empty_input @@ -5111,6 +5111,27 @@ def run_test(input, grad_output): grad = grad.permute(0, 2, 1, 3) run_test(input, grad) + @unittest.skipIf(not TEST_HIPDNN, "hipDNN not available") + def test_batchnorm_hipdnn_backend_selection(self): + # impl_index: 0=Native, 1=cuDNN, 2=MIOpen, 3=hipDNN + c = 16 + bn = torch.nn.BatchNorm2d(c).cuda() + input = torch.randn(4, c, 8, 8, device="cuda") + + # With hipdnn enabled, should select hipDNN backend (index 3) + with torch.backends.hipdnn.flags(enabled=True): + _, _, _, _, impl_index = torch._batch_norm_impl_index( + input, bn.weight, bn.bias, bn.running_mean, bn.running_var, + bn.training, bn.momentum, bn.eps, torch.backends.cudnn.enabled) + self.assertEqual(impl_index, 3) + + # With hipdnn disabled, should fall back to MIOpen (index 2) + with torch.backends.hipdnn.flags(enabled=False): + _, _, _, _, impl_index = torch._batch_norm_impl_index( + input, bn.weight, bn.bias, bn.running_mean, bn.running_var, + bn.training, bn.momentum, bn.eps, torch.backends.cudnn.enabled) + self.assertEqual(impl_index, 2) + @unittest.skipIf(not TEST_CUDA, "CUDA unavailable") def test_batchnorm_cudnn_half(self): # THNN @@ -5222,6 +5243,8 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @parametrize_test("backend", ["default", "hipdnn"] if TEST_HIPDNN else ["default"], + name_fn=lambda x: x if x == "hipdnn" else "") @parametrize_test("dims", [2, 3], name_fn=lambda x: f"{x}D") @parametrize_test("mode", 
["train", "inference"], name_fn=lambda x: x) @parametrize_test( @@ -5258,7 +5281,7 @@ def test_batchnorm_buffer_update_when_stats_are_not_tracked(self): ], name_fn=lambda f, b, m, t: f"{f}_vs_{b}{'_mixed' if m else ''}_{dtype_name(t)}" ) - def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype): + def test_batchnorm(self, backend, dims, mode, memory_format, ref_backend, mixed, dtype): if torch.version.cuda: if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16", "test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16", @@ -5269,13 +5292,16 @@ def test_batchnorm(self, dims, mode, memory_format, ref_backend, mixed, dtype): if torch.version.hip: if self._testMethodName in ("test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16", - "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16") \ + "test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16", + "test_batchnorm_hipdnn_2D_train_NCHW_vs_native_mixed_bfloat16", + "test_batchnorm_hipdnn_3D_train_NCHW_vs_native_mixed_bfloat16") \ and _get_torch_rocm_version() >= (6, 4): # https://github.com/pytorch/pytorch/issues/156513 self.skipTest("bfloat16 NCHW train failed due to native tolerance issue") - if self._testMethodName == "test_batchnorm_3D_train_NCHW_vs_native_mixed_float16": - self.skipTest("3D float16 NCHW train failed on ROCm") + if self._testMethodName in ("test_batchnorm_3D_train_NCHW_vs_native_mixed_float16", + "test_batchnorm_hipdnn_3D_train_NCHW_vs_native_mixed_float16"): + self.skipTest("3D float16 NCHW train failed on ROCm due to native tolerance issue") if dims == 3 and memory_format in ("NHWC", "NCHW"): memory_format = memory_format + "3D" @@ -5387,10 +5413,11 @@ def _inference(memory_format_name, ref_backend, mixed, dtype): ref_out = ref_mod(ref_inp) self.assertEqual(out, ref_out) - if mode == "train": - _train(memory_format, ref_backend, mixed, dtype) - else: - _inference(memory_format, ref_backend, mixed, dtype) + with torch.backends.hipdnn.flags(enabled=True) 
if backend == "hipdnn" else contextlib.nullcontext(): + if mode == "train": + _train(memory_format, ref_backend, mixed, dtype) + else: + _inference(memory_format, ref_backend, mixed, dtype) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_batchnorm_nhwc_cuda(self): diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index fa611a88889b0..3d40990aadf31 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -2854,6 +2854,16 @@ - name: mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) +# hipdnn +- name: hipdnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + input, weight, bias: "grad.defined() ? (training ? hipdnn_batch_norm_backward(input, grad.contiguous(input.suggest_memory_format()), weight, running_mean, running_var, result1, result2, epsilon) : native_batch_norm_backward(grad, input, weight, running_mean, running_var, result1, result2, training, epsilon, grad_input_mask)) : std::tuple()" + result0: batch_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, running_mean, running_var, result1, result2, training, epsilon) + +- name: hipdnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? 
save_var, float epsilon) -> (Tensor, Tensor, Tensor) + save_mean: not_implemented("hipdnn_batch_norm_backward save_mean") + save_var: not_implemented("hipdnn_batch_norm_backward save_var") + input, weight, grad_output: batchnorm_double_backward(input, weight, grads[0], grads[1], grads[2], grad_output, running_mean, running_var, true, epsilon, save_mean, save_var, grad_input_mask) + # mkldnn - name: mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor self, weight, bias: "grad.defined() ? convolution_backward_symint(grad, self, weight, bias->sym_sizes(), stride, padding, dilation, /*transposed=*/ false, /*output_padding=*/ std::vector(padding.size(), 0), groups, grad_input_mask) : std::tuple()" diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 2553f9e60de96..4eeacc671ab75 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -335,6 +335,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.cudnn_batch_norm, aten.cudnn_batch_norm_backward, aten.miopen_batch_norm_backward, + aten.hipdnn_batch_norm_backward, aten.deg2rad, aten.deg2rad_, aten.detach, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 75087863b52f9..97d0f051409ab 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2627,6 +2627,32 @@ def miopen_batch_norm_backward( ) +@register_decomposition(aten.hipdnn_batch_norm_backward) +@out_wrapper("out0", "out1", "out2") +def hipdnn_batch_norm_backward( + input: Tensor, + grad_output: Tensor, + weight: Tensor, + running_mean: Tensor | None, + running_var: Tensor | None, + save_mean: Tensor | None, + save_var: Tensor | None, + epsilon: float, +): + return aten.native_batch_norm_backward( + grad_output, + input, + weight, + running_mean, + running_var, + save_mean, + save_var, + True, + epsilon, + [True, True, True], + ) + + 
@register_decomposition(aten.cudnn_batch_norm_backward) @out_wrapper("out0", "out1", "out2") def cudnn_batch_norm_backward( diff --git a/torch/_decomp/decompositions_for_jvp.py b/torch/_decomp/decompositions_for_jvp.py index 45140652306c6..8258b0b4874fe 100644 --- a/torch/_decomp/decompositions_for_jvp.py +++ b/torch/_decomp/decompositions_for_jvp.py @@ -342,3 +342,4 @@ def batch_norm_backward( _register_jit_decomposition_for_jvp(torch.ops.aten.cudnn_batch_norm_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.batch_norm_backward.default) _register_jit_decomposition_for_jvp(torch.ops.aten.miopen_batch_norm_backward.default) +_register_jit_decomposition_for_jvp(torch.ops.aten.hipdnn_batch_norm_backward.default) diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 3cdcae081b808..fc2a206e4027d 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -3069,7 +3069,8 @@ Call this whenever a new thread is created in order to propagate values from py::enum_(py_module, "_BatchNormBackend") .value("Native", at::native::BatchNormBackend::Native) .value("Cudnn", at::native::BatchNormBackend::Cudnn) - .value("Miopen", at::native::BatchNormBackend::Miopen); + .value("Miopen", at::native::BatchNormBackend::Miopen) + .value("Hipdnn", at::native::BatchNormBackend::Hipdnn); py_module.def( "_select_batch_norm_backend", From fe50fbc0508906dcdcf0bd2b487580278c268a2d Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Mon, 30 Mar 2026 14:48:25 -0700 Subject: [PATCH 04/10] [ROCm] Stop hipifying native/cudnn/ and native/quantized/cudnn/ files MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These files are guarded by `#if AT_CUDNN_ENABLED()` which is always 0 on ROCm, so only stub implementations compile. Hipify was making text substitutions (cudnnHandle_t → miopenHandle_t, etc.) to code that is entirely dead on ROCm. 
Include fixes needed to compile without hipify: - RNN.cpp: move CUDAEvent.h, CUDAGraphsUtils.cuh, Exceptions.h into the `#else // AT_CUDNN_ENABLED()` block (only used by real impl) - LossCTC.cpp: remove unused CUDAGraphsUtils.cuh include - BatchNorm.cpp, Module.cpp, attention.cu, attention_backward.cu: remove `#ifdef __HIP_PLATFORM_AMD__` guards that selected hipified header paths (cudnn/hip/MHA.h, cudnn/hip/BatchNorm.h) — use the originals directly since hipify no longer runs on these files The quantized/cudnn/ files additionally had redundant `#ifdef USE_CUDA` guards wrapping the entire file. These are only compiled in CUDA/ROCm builds (gated by cmake), so the guards were dead code. Authored with Claude. --- aten/src/ATen/CMakeLists.txt | 4 ++-- aten/src/ATen/native/cudnn/BatchNorm.cpp | 4 ---- aten/src/ATen/native/cudnn/LossCTC.cpp | 1 - aten/src/ATen/native/cudnn/RNN.cpp | 6 +++--- aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/Conv.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/Linear.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp | 2 -- aten/src/ATen/native/quantized/cudnn/utils.h | 2 -- aten/src/ATen/native/transformers/cuda/attention.cu | 4 ---- .../src/ATen/native/transformers/cuda/attention_backward.cu | 4 ---- tools/amd_build/build_amd.py | 2 -- torch/csrc/Module.cpp | 4 ---- 16 files changed, 5 insertions(+), 40 deletions(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 92cd4827d0901..50fd98deaaa1d 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -644,7 +644,7 @@ if(USE_ROCM) ${native_sparse_hip_cpp} ${native_quantized_hip_cpp} ${native_transformers_hip_cpp} - ${native_quantized_cudnn_hip_cpp} + ${native_quantized_cudnn_cpp} ${hip_cpp} 
${native_hip_cpp} ${native_hip_linalg_cpp} @@ -652,7 +652,7 @@ if(USE_ROCM) ${ATen_HIP_SRCS} ${native_miopen_cpp} ${native_hipdnn_cpp} - ${native_cudnn_hip_cpp} + ${native_cudnn_cpp} ${miopen_cpp} ${all_hip_cpp} ) diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 7556de8245af1..48e120ac74e86 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -3,11 +3,7 @@ #include #include -#ifdef __HIP_PLATFORM_AMD__ -#include -#else #include -#endif #if !AT_CUDNN_ENABLED() diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp index 1b617199330fb..6a149c13e1bcf 100644 --- a/aten/src/ATen/native/cudnn/LossCTC.cpp +++ b/aten/src/ATen/native/cudnn/LossCTC.cpp @@ -2,7 +2,6 @@ #include #include #include -#include #if AT_CUDNN_ENABLED() #include #endif diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index a45fda2b22db4..1a5dbd7e3eac1 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -4,14 +4,11 @@ #include #include #include -#include -#include #include #include #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -113,6 +110,9 @@ Tensor _cudnn_init_dropout_state( #else // AT_CUDNN_ENABLED() +#include +#include +#include #include namespace at::native { diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index a81332e5999ee..761a33cb3349b 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -257,4 +256,3 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace at::native #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/Conv.cpp 
b/aten/src/ATen/native/quantized/cudnn/Conv.cpp index 6424000594ee9..f4d6583a56cf2 100644 --- a/aten/src/ATen/native/quantized/cudnn/Conv.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Conv.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -402,4 +401,3 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp index 1e1811a0b2c45..10c6f5184c0f7 100644 --- a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -208,4 +207,3 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { } // namespace at::native #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp index fbb4a1fe94111..3d880bb13df98 100644 --- a/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp +++ b/aten/src/ATen/native/quantized/cudnn/ConvUnpackImpl.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -20,4 +19,3 @@ template std::tuple> PackedConvWeightCudnn 2>::unpack(); #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/Linear.cpp b/aten/src/ATen/native/quantized/cudnn/Linear.cpp index 230850998fda1..7c4e0694ea4c5 100644 --- a/aten/src/ATen/native/quantized/cudnn/Linear.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Linear.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -367,4 +366,3 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git 
a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp index 3b01841c4aa87..9587ee4a85b77 100644 --- a/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/LinearPrepack.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -56,4 +55,3 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) { #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp index 40088052cd151..7ed022986d02f 100644 --- a/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp +++ b/aten/src/ATen/native/quantized/cudnn/LinearUnpackImpl.cpp @@ -1,4 +1,3 @@ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -15,4 +14,3 @@ std::tuple> PackedLinearWeightCudnn::unpac } #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/quantized/cudnn/utils.h b/aten/src/ATen/native/quantized/cudnn/utils.h index 0b46f743fa68d..b94271558e13a 100644 --- a/aten/src/ATen/native/quantized/cudnn/utils.h +++ b/aten/src/ATen/native/quantized/cudnn/utils.h @@ -3,7 +3,6 @@ This file contains some of the auxiliary functions used by both Conv.cpp & Linear.cpp (introduced in a later PR) */ -#ifdef USE_CUDA #include // for the definition of AT_CUDNN_ENABLED #if AT_CUDNN_ENABLED() @@ -312,4 +311,3 @@ inline void filterEngineConfigs( } // cudnn_utils #endif // AT_CUDNN_ENABLED -#endif // USE_CUDA diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 69ecf31df0586..b47a45aec7d28 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -56,11 +56,7 @@ #include #endif -#ifdef __HIP_PLATFORM_AMD__ -#include -#else #include -#endif #include diff 
--git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 183f99e975cda..b53e6bee983a2 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -59,11 +59,7 @@ #endif #endif -#ifdef __HIP_PLATFORM_AMD__ -#include -#else #include -#endif namespace at::native { diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index 7f9289cae498d..b72ffe675c19e 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -89,8 +89,6 @@ # Keep this synchronized with is_pytorch_file in hipify_python.py "aten/src/ATen/cuda/*", "aten/src/ATen/native/cuda/*", - "aten/src/ATen/native/cudnn/*", - "aten/src/ATen/native/quantized/cudnn/*", "aten/src/ATen/native/nested/cuda/*", "aten/src/ATen/native/sparse/cuda/*", "aten/src/ATen/native/quantized/cuda/*", diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index fc2a206e4027d..ddac0c1552c03 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -118,12 +118,8 @@ #include #include #include -#ifdef __HIP_PLATFORM_AMD__ -#include -#else #include #endif -#endif #ifdef USE_XPU #include From 3368247e5f7a544dc8e44c59ff7257a427d2a48f Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Thu, 16 Apr 2026 17:21:03 -0700 Subject: [PATCH 05/10] Add aotriton.images/ to .gitignore Building with USE_FLASH_ATTENTION=ON on ROCm copies precompiled AOTriton kernel images into torch/lib/aotriton.images/. These are binary GPU kernels for flash and efficient attention, shipped precompiled in the ROCm SDK. 
Co-Authored-By: Claude Opus 4 (1M context) --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 5105977ae52c0..2c6d782fd98e1 100644 --- a/.gitignore +++ b/.gitignore @@ -106,6 +106,7 @@ torch/lib/*.lib torch/lib/*.pdb torch/lib/*.so* torch/lib/protobuf*.pc +torch/lib/aotriton.images/ torch/lib/build torch/lib/caffe2/ torch/lib/cmake From a25304dcb99407e7f6ffbd97bd921117229b6945 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 14 Apr 2026 12:16:51 -0700 Subject: [PATCH 06/10] Extract compute_matching_strides from alloc_with_matching_layout Factor out the stride computation logic from alloc_with_matching_layout into a standalone compute_matching_strides function that returns strides without allocating a tensor. This allows callers that only need the output stride metadata (e.g., graph-based support checks) to avoid unnecessary GPU tensor allocations. For the same-size case, delegates to infer_dense_strides to match empty_like's compaction behavior on non-dense inputs. For different sizes, computes dense strides preserving the reference tensor's dimension ordering. 
Co-Authored-By: Claude Opus 4.6 --- aten/src/ATen/native/transformers/sdp_utils.h | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/transformers/sdp_utils.h b/aten/src/ATen/native/transformers/sdp_utils.h index 18f13781944f8..fca993871cc77 100644 --- a/aten/src/ATen/native/transformers/sdp_utils.h +++ b/aten/src/ATen/native/transformers/sdp_utils.h @@ -1,34 +1,32 @@ #pragma once #include +#include #include namespace at::native { -void alloc_with_matching_layout( - const Tensor& q, - Tensor& output, - const std::vector& shape) { - TORCH_INTERNAL_ASSERT( - shape.size() == q.sizes().size(), - "SDPA alloc_with_matching_layout got requested shape ndim != q ndim"); - - if (std::equal(q.sizes().begin(), q.sizes().end(), shape.begin())) { - output = at::empty_like(q); - return; +// Compute dense strides that preserve the dimension ordering of ref_strides +// for the given shape. When ref_sizes == shape, uses infer_dense_strides to +// match empty_like's behavior (compacting non-dense gaps). +inline std::vector compute_matching_strides( + IntArrayRef ref_sizes, + IntArrayRef ref_strides, + IntArrayRef shape) { + if (ref_sizes.equals(shape)) { + return infer_dense_strides(ref_sizes, ref_strides); } // get the "fill order," which is just an argsort on the strides std::vector fill_order(shape.size()); std::iota(fill_order.begin(), fill_order.end(), 0); - const auto q_strides = q.strides(); // note: why INT64_MAX instead of 1. // When Q's strides include 0, e.g. (0, 0, 128, 1), mapping stride 0 to 1 leads to // fill_order of [0, 1, 3, 2], i.e. the output strides are [1, 8, 1024, 16]. // To match output strides with Q, use INT64_MAx so that broadcast dims come last in fill_order. std::stable_sort( - fill_order.begin(), fill_order.end(), [&q_strides](int idx1, int idx2) { - int64_t s1 = q_strides[idx1] ? q_strides[idx1] : INT64_MAX; - int64_t s2 = q_strides[idx2] ? 
q_strides[idx2] : INT64_MAX; + fill_order.begin(), fill_order.end(), [&ref_strides](int idx1, int idx2) { + int64_t s1 = ref_strides[idx1] ? ref_strides[idx1] : INT64_MAX; + int64_t s2 = ref_strides[idx2] ? ref_strides[idx2] : INT64_MAX; return s1 < s2; }); std::vector ordered_strides(shape.size()); @@ -37,7 +35,19 @@ void alloc_with_matching_layout( ordered_strides[dim_idx] = current_stride; current_stride *= shape[dim_idx]; } - output = at::empty_strided(at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), q.options()); + return ordered_strides; +} + +void alloc_with_matching_layout( + const Tensor& q, + Tensor& output, + const std::vector& shape) { + TORCH_INTERNAL_ASSERT( + shape.size() == q.sizes().size(), + "SDPA alloc_with_matching_layout got requested shape ndim != q ndim"); + + auto strides = compute_matching_strides(q.sizes(), q.strides(), shape); + output = at::empty_strided(shape, strides, q.options()); } void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) { From 6f5ae939405601243e136c37ad423fe17602a4b2 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Thu, 7 May 2026 13:32:21 -0700 Subject: [PATCH 07/10] [ROCm] Plumb user-supplied SDPA scale through dispatch Add an optional 'scale' field to sdp::sdp_params and populate it from every dispatch entry point (_fused_sdp_choice_cpp/_cuda/_xpu, the transformer_encoder helper, and the SDPAParams Python binding). The hipDNN backend reads this in check_cudnn_sdpa_support so the support query and the eventual graph build see the actual scale the user passed, instead of always defaulting to 1/sqrt(head_dim). 
--- aten/src/ATen/native/mkldnn/xpu/Attention.cpp | 2 +- .../ATen/native/transformers/attention.cpp | 2 +- .../native/transformers/cuda/attention.cu | 4 ++-- .../ATen/native/transformers/sdp_utils_cpp.h | 1 + torch/csrc/Module.cpp | 19 +++++++++++++++---- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp index 581facd59d9b7..5537f48b63e53 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Attention.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Attention.cpp @@ -262,7 +262,7 @@ int64_t _fused_sdp_choice_xpu( std::optional scale, bool enable_gqa) { sdp::sdp_params kernel_params{ - query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa}; + query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa, scale}; auto backend = select_sdp_backend_xpu(kernel_params); if (backend == sdp::SDPBackend::error) { diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index eecfe1058c2c4..74e2838851880 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -432,7 +432,7 @@ std::tuple native_multi_head_attention_cpu( int64_t _fused_sdp_choice_cpp(const Tensor& query_, const Tensor& key, const Tensor& value, const std::optional& attn_mask_, double dropout_p, bool is_causal, std::optional scale, bool enable_gqa){ - sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa}; + sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa, scale}; auto backend = sdp::select_sdp_backend_cpp(kernel_params); if (backend == sdp::SDPBackend::error) { TORCH_CHECK( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index b47a45aec7d28..6367a6770d484 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ 
b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -784,7 +784,7 @@ std::tuple native_multi_head_attention_cuda( auto k = key.view({key.size(0), -1, num_head, dim_per_head}).transpose(1, 2); auto v = value.view({value.size(0), -1, num_head, dim_per_head}).transpose(1, 2); - sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false}; + sdp::sdp_params kernel_params{q, k, v, mask, 0.0, false, false, /*scale=*/std::nullopt}; auto backend = select_sdp_backend(kernel_params); // strides from packed projection for nested tensors when seq_len is 1 will be // and will trigger a contiguous call in the kernel, so we prevent this @@ -1276,7 +1276,7 @@ std::tuple _scaled_dot_product_efficient_attenti int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Tensor& value, const std::optional& attn_mask_, double dropout_p, bool is_causal, std::optional scale, bool enable_gqa){ - sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa}; + sdp::sdp_params kernel_params{query_, key, value, attn_mask_, dropout_p, is_causal, enable_gqa, scale}; auto backend = select_sdp_backend(kernel_params); if (backend == sdp::SDPBackend::error) { TORCH_CHECK( diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h index eb6a92dd5411a..4988c83abb178 100644 --- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h +++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h @@ -42,6 +42,7 @@ struct sdp_params { double dropout; bool is_causal; bool enable_gqa; + std::optional scale; }; SDPBackend select_sdp_backend_cpp(sdp_params const& kernel_params); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index ddac0c1552c03..adaceabb07f94 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2673,7 +2673,8 @@ Call this whenever a new thread is created in order to propagate values from std::optional attn_mask, double dropout, bool is_causal, - bool enable_gqa) 
{ + bool enable_gqa, + std::optional scale) { return sdp::sdp_params{ query, key, @@ -2681,15 +2682,25 @@ Call this whenever a new thread is created in order to propagate values from std::move(attn_mask), dropout, is_causal, - enable_gqa}; - })) + enable_gqa, + scale}; + }), + py::arg("query"), + py::arg("key"), + py::arg("value"), + py::arg("attn_mask"), + py::arg("dropout"), + py::arg("is_causal"), + py::arg("enable_gqa"), + py::arg("scale") = std::nullopt) .def_readonly("query", &sdp::sdp_params::query) .def_readonly("key", &sdp::sdp_params::key) .def_readonly("value", &sdp::sdp_params::value) .def_readonly("attn_mask", &sdp::sdp_params::attn_mask) .def_readonly("dropout", &sdp::sdp_params::dropout) .def_readonly("is_causal", &sdp::sdp_params::is_causal) - .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa); + .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa) + .def_readonly("scale", &sdp::sdp_params::scale); py::enum_( py_module, From 9bb87526747d20d96c79e2b0f443d981fa8355d6 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 14 Apr 2026 10:44:25 -0700 Subject: [PATCH 08/10] [ROCm] Integrate hipDNN as an SDPA backend Uses shim headers to re-route CUDA includes to hip/hipdnn. Some differences still need conditional logic in the source files: - hipDNN requires `scale` value to be set on the graph as a constant, while cuDNN accepts it through a host buffer at execute time. This requires an extra field in the cache key for the scale value. 
- code dependent on cuDNN version requires logic to handle hipDNN - ragged/nested tensors aren't supported on hipDNN due to missing APIs: - set_seq_len_{q,k,v} - set_padding_mask - set_ragged_offset - cuDNN constraints are checked in pytorch, rather than through API queries - increases coupling, but *does* allow requirements to be checked symbolically without requiring concrete dimension values --- aten/src/ATen/CMakeLists.txt | 1 - .../hip_compat/include/ATen/cuda/Exceptions.h | 8 + .../hip_compat/include/ATen/cudnn/Handle.h | 13 + .../include/c10/cuda/impl/cuda_cmake_macros.h | 5 + aten/src/ATen/hip_compat/include/cuda.h | 5 + .../ATen/hip_compat/include/cuda_runtime.h | 6 + .../hip_compat/include/cuda_runtime_api.h | 36 ++ .../ATen/hip_compat/include/cudnn_frontend.h | 62 +++ aten/src/ATen/native/cudnn/MHA.cpp | 482 +++++++++++++----- aten/src/ATen/native/cudnn/MHA.h | 7 + .../ATen/native/transformers/attention.cpp | 12 +- .../native/transformers/cuda/sdp_utils.cpp | 85 ++- caffe2/CMakeLists.txt | 6 + test/test_transformers.py | 46 ++ 14 files changed, 640 insertions(+), 134 deletions(-) create mode 100644 aten/src/ATen/hip_compat/include/ATen/cuda/Exceptions.h create mode 100644 aten/src/ATen/hip_compat/include/ATen/cudnn/Handle.h create mode 100644 aten/src/ATen/hip_compat/include/c10/cuda/impl/cuda_cmake_macros.h create mode 100644 aten/src/ATen/hip_compat/include/cuda.h create mode 100644 aten/src/ATen/hip_compat/include/cuda_runtime.h create mode 100644 aten/src/ATen/hip_compat/include/cuda_runtime_api.h create mode 100644 aten/src/ATen/hip_compat/include/cudnn_frontend.h diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 50fd98deaaa1d..2b26951aa0e79 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -182,7 +182,6 @@ file(GLOB native_hip_cpp "native/hip/*.cpp") file(GLOB native_hip_linalg_cpp "native/hip/linalg/*.cpp") file(GLOB native_miopen_cpp "native/miopen/*.cpp") file(GLOB native_hipdnn_cpp 
"native/hipdnn/*.cpp") -file(GLOB native_cudnn_hip_cpp "native/cudnn/hip/*.cpp") file(GLOB native_nested_hip_hip "native/nested/hip/*.hip") file(GLOB native_nested_hip_cpp "native/nested/hip/*.cpp") file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip") diff --git a/aten/src/ATen/hip_compat/include/ATen/cuda/Exceptions.h b/aten/src/ATen/hip_compat/include/ATen/cuda/Exceptions.h new file mode 100644 index 0000000000000..11c156cb81c03 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/ATen/cuda/Exceptions.h @@ -0,0 +1,8 @@ +#pragma once + +// Shim of `` for HIP builds. Defines +// AT_CUDNN_FRONTEND_CHECK in terms of hipDNN's check macro. + +#include + +#define AT_CUDNN_FRONTEND_CHECK(e) HIPDNN_FE_CHECK(e) diff --git a/aten/src/ATen/hip_compat/include/ATen/cudnn/Handle.h b/aten/src/ATen/hip_compat/include/ATen/cudnn/Handle.h new file mode 100644 index 0000000000000..76399e6b78eb2 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/ATen/cudnn/Handle.h @@ -0,0 +1,13 @@ +#pragma once + +// Shim of `` for HIP builds. Forwards the cuDNN handle +// symbols to their hipDNN equivalents so non-hipified files compile against +// the cuDNN-named API. + +#include +#include + +using cudnnHandle_t = hipdnnHandle_t; +inline cudnnHandle_t getCudnnHandle() { + return at::native::getHipdnnHandle(); +} diff --git a/aten/src/ATen/hip_compat/include/c10/cuda/impl/cuda_cmake_macros.h b/aten/src/ATen/hip_compat/include/c10/cuda/impl/cuda_cmake_macros.h new file mode 100644 index 0000000000000..1dabc1510bd30 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/c10/cuda/impl/cuda_cmake_macros.h @@ -0,0 +1,5 @@ +#pragma once + +// CMake-generated cuda_cmake_macros.h doesn't exist on HIP builds; forward +// to its hip equivalent so c10/cuda/CUDAMacros.h transitive-includes work. 
+#include diff --git a/aten/src/ATen/hip_compat/include/cuda.h b/aten/src/ATen/hip_compat/include/cuda.h new file mode 100644 index 0000000000000..704ffee5b8579 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/cuda.h @@ -0,0 +1,5 @@ +#pragma once + +// CUDA driver API header; on HIP, forward to hip_runtime which exposes the +// equivalent driver entry points used by c10/cuda/CUDAException.h. +#include diff --git a/aten/src/ATen/hip_compat/include/cuda_runtime.h b/aten/src/ATen/hip_compat/include/cuda_runtime.h new file mode 100644 index 0000000000000..3da038d0afd11 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/cuda_runtime.h @@ -0,0 +1,6 @@ +#pragma once + +// `cuda_runtime.h` is the catch-all CUDA SDK header; on HIP builds, forward +// to the equivalent. cuda_runtime_api.h carries the type/function aliases. +#include +#include diff --git a/aten/src/ATen/hip_compat/include/cuda_runtime_api.h b/aten/src/ATen/hip_compat/include/cuda_runtime_api.h new file mode 100644 index 0000000000000..16f5a70aceca4 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/cuda_runtime_api.h @@ -0,0 +1,36 @@ +#pragma once + +// Drop-in shim so non-hipified files including on a +// HIP build compile. Forwards to hip_runtime_api.h and aliases the cuda* +// types/enums/functions used by c10/cuda/* headers to their hip* +// equivalents — mirrors the source-level rewrites that hipify performs. + +#include + +using cudaStream_t = hipStream_t; +using cudaError_t = hipError_t; +using cudaMemcpyKind = hipMemcpyKind; +using cudaStreamCaptureMode = hipStreamCaptureMode; +using cudaStreamCaptureStatus = hipStreamCaptureStatus; + +// Enum values are accessed via `cudaStreamCaptureStatus::cudaStreamCaptureStatusNone` +// (C++ scope resolution into the enum), which becomes +// `hipStreamCaptureStatus::cudaStreamCaptureStatusNone` after the type alias. +// The inner identifier needs to be a macro that substitutes to the hip-named +// value so the lookup hits the enum's actual member. 
+#define cudaSuccess hipSuccess +#define cudaStreamCaptureStatusNone hipStreamCaptureStatusNone +#define cudaStreamCaptureStatusActive hipStreamCaptureStatusActive +#define cudaStreamCaptureStatusInvalidated hipStreamCaptureStatusInvalidated + +#define cudaMemGetInfo hipMemGetInfo +#define cudaMallocAsync hipMallocAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamGetPriority hipStreamGetPriority +#define cudaStreamIsCapturing hipStreamIsCapturing +#define cudaStreamGetCaptureInfo hipStreamGetCaptureInfo +#define cudaThreadExchangeStreamCaptureMode hipThreadExchangeStreamCaptureMode +#define cudaGetLastError hipGetLastError +#define cudaGetErrorString hipGetErrorString +#define cudaDeviceGetStreamPriorityRange hipDeviceGetStreamPriorityRange diff --git a/aten/src/ATen/hip_compat/include/cudnn_frontend.h b/aten/src/ATen/hip_compat/include/cudnn_frontend.h new file mode 100644 index 0000000000000..7fc7804782bc4 --- /dev/null +++ b/aten/src/ATen/hip_compat/include/cudnn_frontend.h @@ -0,0 +1,62 @@ +#pragma once + +// Shim of cuDNN's `` for ROCm/hipDNN builds. Forwards +// `cudnn_frontend` symbols to `hipdnn_frontend` with cuDNN-style API shims +// layered on top (Graph::check_support(handle), HeurMode_t::A, etc.), and +// also forwards cuDNN-side `cudnnHandle_t`/`getCudnnHandle()`/ +// `AT_CUDNN_FRONTEND_CHECK` to hipDNN equivalents — so non-hipified cuDNN +// call sites compile unchanged on HIP. + +// TODO: drop this define once hipDNN exposes SDPA unconditionally and +// pytorch's LoadHIP.cmake propagates hipdnn_frontend's +// INTERFACE_COMPILE_DEFINITIONS. 
+#define HIPDNN_ENABLE_SDPA + +#include +#include +#include +#include + +namespace at::native::hipdnn_compat { + +using namespace hipdnn_frontend; + +namespace graph { +using namespace hipdnn_frontend::graph; + +class Graph : public hipdnn_frontend::graph::Graph { + public: + // cuDNN's check_support / build_plans take a handle; hipDNN's don't (the + // handle is bound at execute time). Add overloads that ignore the handle + // and forward to the no-arg APIs. + using hipdnn_frontend::graph::Graph::check_support; + using hipdnn_frontend::graph::Graph::build_plans; + auto check_support(hipdnnHandle_t /*handle*/) { return check_support(); } + auto build_plans(hipdnnHandle_t /*handle*/) { return build_plans(); } + + // cuDNN exposes a per-uid query via an out-parameter. hipDNN only offers + // a one-shot {uid -> shared_ptr} map; wrap it. + hipdnn_frontend::error_t query_tensor_attributes_of_uid( + int64_t uid, + hipdnn_frontend::graph::Tensor_attributes& attrs) const { + auto graph_tensors = getTensorsByUid(); + auto it = graph_tensors.find(uid); + if (it == graph_tensors.end()) { + return {hipdnn_frontend::error_code_t::ATTRIBUTE_NOT_SET, + "tensor uid not in graph"}; + } + attrs = *it->second; + return {hipdnn_frontend::error_code_t::OK, ""}; + } +}; + +} // namespace graph + +// Map cuDNN's HeurMode_t::A (recommended heuristic) to FALLBACK on hipDNN. 
+struct HeurMode_t { + static constexpr auto A = hipdnn_frontend::HeurMode_t::FALLBACK; +}; + +} // namespace at::native::hipdnn_compat + +namespace cudnn_frontend = at::native::hipdnn_compat; diff --git a/aten/src/ATen/native/cudnn/MHA.cpp b/aten/src/ATen/native/cudnn/MHA.cpp index f36b299e876f3..d8fc64fac8953 100644 --- a/aten/src/ATen/native/cudnn/MHA.cpp +++ b/aten/src/ATen/native/cudnn/MHA.cpp @@ -2,13 +2,14 @@ #include #include -#if AT_CUDNN_ENABLED() +#if defined(USE_HIPDNN) || AT_CUDNN_ENABLED() #include #endif -#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || \ - (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) || \ - (defined(CUDNN_FRONTEND_VERSION) && CUDNN_FRONTEND_VERSION < 10100) +#if (defined(USE_ROCM) && !defined(USE_HIPDNN)) || \ + (!defined(USE_ROCM) && (!AT_CUDNN_ENABLED() || \ + (defined(CUDNN_VERSION) && CUDNN_VERSION < 8900) || \ + (defined(CUDNN_FRONTEND_VERSION) && CUDNN_FRONTEND_VERSION < 10100))) namespace at { namespace native { @@ -122,10 +123,8 @@ void run_cudnn_SDP_bprop_nestedtensor( } // namespace native } // namespace at -#else // AT_CUDNN_ENABLED && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8900 -#include -#include -#include +#else +#include #include #include @@ -135,7 +134,6 @@ void run_cudnn_SDP_bprop_nestedtensor( #include #include -#include #include @@ -195,17 +193,60 @@ int roundup_power2(int dim) { return dim; } +// TODO: replace with a shared cuDNN/hipDNN dtype util once one exists. 
+inline fe::DataType_t to_fe_data_type(c10::ScalarType t) { + switch (t) { + case kHalf: + return fe::DataType_t::HALF; + case kBFloat16: + return fe::DataType_t::BFLOAT16; + case kFloat: + return fe::DataType_t::FLOAT; + case kDouble: + return fe::DataType_t::DOUBLE; + case kBool: + return fe::DataType_t::BOOLEAN; + case kInt: + return fe::DataType_t::INT32; + case kLong: + return fe::DataType_t::INT64; + default: + TORCH_CHECK(false, "cuDNN/hipDNN SDPA: unexpected tensor dtype ", t); + } +} + +// Asserts the runtime tensor's dim/stride/dtype match the graph's +// Tensor_attributes for the given UID. Used as a defensive check in +// run_cudnn_SDP_fprop / _bprop to catch cache-key bugs where a graph +// would otherwise be executed against tensors with different metadata. +static void check_tensor_matches_graph( + const fe::graph::Graph& graph, + int64_t uid, + const Tensor& t) { + fe::graph::Tensor_attributes attrs; + AT_CUDNN_FRONTEND_CHECK(graph.query_tensor_attributes_of_uid(uid, attrs)); + TORCH_CHECK(t.sizes() == IntArrayRef(attrs.get_dim())); + TORCH_CHECK(t.strides() == IntArrayRef(attrs.get_stride())); + TORCH_CHECK(to_fe_data_type(t.scalar_type()) == attrs.get_data_type()); +} + struct MHAParams { c10::DeviceIndex device_id; fe::DataType_t dataType; - std::array q_dim; - std::array k_dim; - std::array v_dim; - std::array q_stride; - std::array k_stride; - std::array v_stride; - std::array bias_dim; - std::array bias_stride; +#ifdef USE_HIPDNN + // hipDNN bakes the scale into the graph, so it must be part of the cache + // key. 
+ float scaling_factor; +#endif + fe::DataType_t bias_dtype; // NOTE: on cuDNN this is always the same as `dataType` + std::array q_dim; + std::array k_dim; + std::array v_dim; + std::array q_stride; + std::array k_stride; + std::array v_stride; + std::array bias_dim; + std::array bias_stride; int64_t b; int64_t h; int64_t s_q; @@ -218,7 +259,7 @@ struct MHAParams { // might be redundant if we take 0 dim/stride // as signaling no-bias bool has_attn_bias; - bool use_ragged; + bool use_ragged; // NOTE: on hipDNN this is always false }; void setMHAParams( @@ -236,6 +277,9 @@ void setMHAParams( double dropout_probability, bool is_causal, bool return_softmaxstats, +#ifdef USE_HIPDNN + float scaling_factor, +#endif bool is_nested) { memset(¶ms, 0, sizeof(MHAParams)); params.device_id = at::cuda::current_device(); @@ -243,6 +287,9 @@ void setMHAParams( if (q.scalar_type() == kBFloat16) { params.dataType = fe::DataType_t::BFLOAT16; } +#ifdef USE_HIPDNN + params.scaling_factor = scaling_factor; +#endif params.b = b; params.h = h; params.d_qk = d_qk; @@ -294,6 +341,7 @@ void setMHAParams( } // uninit is OK as the struct is memset 0'd if (params.has_attn_bias) { + params.bias_dtype = to_fe_data_type(attn_bias.value().scalar_type()); std::copy( attn_bias.value().sizes().begin(), attn_bias.value().sizes().end(), @@ -320,6 +368,9 @@ struct MHACacheKeyWrapper : ParamsWrapper { double dropout_probability, bool is_causal, bool return_softmaxstats, +#ifdef USE_HIPDNN + float scaling_factor, +#endif bool is_nested) { setMHAParams( this->pod, @@ -336,6 +387,9 @@ struct MHACacheKeyWrapper : ParamsWrapper { dropout_probability, is_causal, return_softmaxstats, +#ifdef USE_HIPDNN + scaling_factor, +#endif is_nested); } }; @@ -459,15 +513,10 @@ std::unique_ptr build_graph( const Tensor& k, const Tensor& v, const std::optional& attn_bias, - Tensor& softmaxstats, - Tensor& o, - Tensor& dropoutseed, - Tensor& dropoutoffset, - cudnnHandle_t& handle) { - auto dtype = fe::DataType_t::HALF; - if 
(q.scalar_type() == kBFloat16) { - dtype = fe::DataType_t::BFLOAT16; - } + cudnnHandle_t& handle, + bool finalize, + bool use_ragged) { + auto dtype = to_fe_data_type(q.scalar_type()); auto mha_graph = std::make_unique(); // We're baking in float accumulation and scale types // in theory the graph may support other types, but they @@ -475,6 +524,21 @@ std::unique_ptr build_graph( mha_graph->set_io_data_type(dtype) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); + auto scaled_dot_product_flash_attention_options = + fe::graph::SDPA_attributes() + .set_name("CUDNN_SDPA") +#if defined(USE_HIPDNN) || CUDNN_FRONTEND_VERSION > 11200 + .set_generate_stats(return_softmaxstats) +#else + .set_is_inference(!return_softmaxstats) +#endif + .set_causal_mask(is_causal); + + // Scale is a constant attribute on hipDNN, and a tensor input on cuDNN. +#ifdef USE_HIPDNN + scaled_dot_product_flash_attention_options + .set_attn_scale_value(scaling_factor); +#else auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SCALE) @@ -483,17 +547,10 @@ std::unique_ptr build_graph( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); - auto scaled_dot_product_flash_attention_options = - fe::graph::SDPA_attributes() - .set_name("CUDNN_SDPA") -#if CUDNN_FRONTEND_VERSION <= 11200 - .set_is_inference(!return_softmaxstats) -#else - .set_generate_stats(return_softmaxstats) -#endif - .set_causal_mask(is_causal) + scaled_dot_product_flash_attention_options .set_attn_scale(attn_scale); - if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { +#endif + if (use_ragged) { auto SEQ_LEN_Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SEQ_LEN_Q) @@ -508,29 +565,29 @@ std::unique_ptr build_graph( .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); +#ifdef USE_HIPDNN + TORCH_CHECK(false, "ragged-in-dense SDPA is not supported on hipDNN"); +#else 
scaled_dot_product_flash_attention_options.set_seq_len_q(SEQ_LEN_Q_) .set_seq_len_kv(SEQ_LEN_KV_) .set_padding_mask(true); +#endif } if (dropout_probability != 0.0f) { + // Hardcode INT64 for dropout seed/offset; attention.cpp / attention.cu always + // allocates INT64 tensors. auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SEED) .set_name("Seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); + .set_data_type(fe::DataType_t::INT64)); auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(OFFSET) .set_name("Offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); + .set_data_type(fe::DataType_t::INT64)); scaled_dot_product_flash_attention_options.set_dropout( dropout_probability, seed, offset); } @@ -546,19 +603,31 @@ std::unique_ptr build_graph( .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) - .set_stride(attn_bias.value().strides().vec()))); + .set_stride(attn_bias.value().strides().vec()) + .set_data_type(to_fe_data_type( + attn_bias.value().scalar_type())))); } auto [O_, Stats] = mha_graph->sdpa(Q_, K_, V_, scaled_dot_product_flash_attention_options); O_->set_uid(O).set_output(true); + // Stats / O metadata is inferred from {b, h, s_q, d_v} + Q's layout. + // PyTorch's dispatch always allocates softmaxstats as + // at::empty({b, h, s_q, 1}) (contiguous {h*s_q, s_q, 1, 1} strides) and o + // via alloc_with_matching_layout(q, o, {b, h, s_q, d_v}) — so the inferred + // metadata equals what reading the actual tensors would produce. + // O sizes and strides are inferred (no `o` tensor at build time): O is + // {b, h, s_q, d_v} matching Q's layout (mirrors alloc_with_matching_layout).
+ std::vector o_sizes = {b, h, s_q, d_v}; + auto o_strides = compute_matching_strides(q.sizes(), q.strides(), o_sizes); if (Stats) { Stats->set_uid(LSE) .set_output(true) .set_data_type(fe::DataType_t::FLOAT) - .set_stride(softmaxstats.strides().vec()); + // Inferred: contiguous {b, h, s_q, 1} layout (matches at::empty). + .set_stride({h * s_q, s_q, 1, 1}); } - if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { + if (use_ragged) { auto RAG_Q_OFF_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(RAG_Q_OFF) @@ -594,14 +663,18 @@ std::unique_ptr build_graph( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); +#ifndef USE_HIPDNN O_->set_ragged_offset(RAG_O_OFF_); Q_->set_ragged_offset(RAG_Q_OFF_); K_->set_ragged_offset(RAG_K_OFF_); V_->set_ragged_offset(RAG_V_OFF_); +#else + TORCH_CHECK(false, "ragged-in-dense SDPA is not supported on hipDNN"); +#endif auto qsizevec = q.sizes().vec(); auto ksizevec = k.sizes().vec(); auto vsizevec = v.sizes().vec(); - auto osizevec = o.sizes().vec(); + auto osizevec = o_sizes; qsizevec[2] = roundup_power2(qsizevec[2]); ksizevec[2] = roundup_power2(ksizevec[2]); vsizevec[2] = roundup_power2(vsizevec[2]); @@ -618,8 +691,10 @@ std::unique_ptr build_graph( O_->set_dim(osizevec).set_stride( {INT_MAX, osizevec[3], osizevec[1] * osizevec[3], 1}); if (Stats) { +#ifndef USE_HIPDNN Stats->set_ragged_offset(RAG_STATS_OFF_); - auto statssizevec = softmaxstats.sizes().vec(); +#endif + std::vector statssizevec = {b, h, s_q, 1}; statssizevec[2] = roundup_power2(statssizevec[2]); Stats->set_dim(statssizevec); } @@ -630,19 +705,21 @@ std::unique_ptr build_graph( .set_stride(fixSizeOneDimStrideSDPA(k.sizes(), k.strides().vec())); V_->set_dim(v.sizes().vec()) .set_stride(fixSizeOneDimStrideSDPA(v.sizes(), v.strides().vec())); - O_->set_dim(o.sizes().vec()) - .set_stride(fixSizeOneDimStrideSDPA(o.sizes(), o.strides().vec())); + O_->set_dim(o_sizes) + 
.set_stride(fixSizeOneDimStrideSDPA(o_sizes, o_strides)); if (Stats) { - Stats->set_dim(softmaxstats.sizes().vec()); + Stats->set_dim({b, h, s_q, 1}); } } - AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); - AT_CUDNN_FRONTEND_CHECK( - mha_graph->create_execution_plans({fe::HeurMode_t::A})); - AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); + if (finalize) { + AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); + AT_CUDNN_FRONTEND_CHECK( + mha_graph->create_execution_plans({fe::HeurMode_t::A})); + AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); + } return mha_graph; } @@ -671,6 +748,9 @@ std::unique_ptr build_graph_nestedtensor( Tensor& dropoutseed, Tensor& dropoutoffset, cudnnHandle_t& handle) { +#ifdef USE_HIPDNN + TORCH_CHECK(false, "nested-tensor SDPA is not supported on hipDNN"); +#else auto dtype = fe::DataType_t::HALF; if (q.scalar_type() == kBFloat16) { dtype = fe::DataType_t::BFLOAT16; @@ -852,6 +932,7 @@ std::unique_ptr build_graph_nestedtensor( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); return mha_graph; +#endif } std::unique_ptr build_graph_backward( @@ -868,19 +949,10 @@ std::unique_ptr build_graph_backward( const Tensor& k, const Tensor& v, const std::optional& attn_bias, - const Tensor& o, - const Tensor& dO, - const Tensor& softmaxstats, - Tensor& dQ, - Tensor& dK, - Tensor& dV, - const Tensor& dropoutseed, - const Tensor& dropoutoffset, - cudnnHandle_t& handle) { - auto dtype = fe::DataType_t::HALF; - if (q.scalar_type() == kBFloat16) { - dtype = fe::DataType_t::BFLOAT16; - } + cudnnHandle_t& handle, + bool finalize, + bool use_ragged) { + auto dtype = to_fe_data_type(q.scalar_type()); auto mha_graph = 
std::make_unique(); // We're baking in float accumulation and scale types // in theory the graph may support other types, but they @@ -888,6 +960,7 @@ std::unique_ptr build_graph_backward( mha_graph->set_io_data_type(dtype) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); +#ifndef USE_HIPDNN auto attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SCALE) @@ -896,11 +969,16 @@ std::unique_ptr build_graph_backward( .set_stride({1, 1, 1, 1}) .set_is_pass_by_value(true) .set_data_type(fe::DataType_t::FLOAT)); +#endif auto sdpa_backward_options = fe::graph::SDPA_backward_attributes() .set_name("CUDNN_SDPA_BACKWARD") .set_causal_mask(is_causal) +#ifdef USE_HIPDNN + .set_attn_scale_value(scaling_factor); +#else .set_attn_scale(attn_scale); - if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { +#endif + if (use_ragged) { auto SEQ_LEN_Q_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(SEQ_LEN_Q) @@ -915,9 +993,13 @@ std::unique_ptr build_graph_backward( .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); +#ifndef USE_HIPDNN sdpa_backward_options.set_seq_len_q(SEQ_LEN_Q_) .set_seq_len_kv(SEQ_LEN_KV_) .set_padding_mask(true); +#else + TORCH_CHECK(false, "ragged-in-dense SDPA is not supported on hipDNN"); +#endif } auto Q_ = mha_graph->tensor( @@ -932,7 +1014,9 @@ std::unique_ptr build_graph_backward( .set_uid(BIAS) .set_name("bias") .set_dim(attn_bias.value().sizes().vec()) - .set_stride(attn_bias.value().strides().vec()))); + .set_stride(attn_bias.value().strides().vec()) + .set_data_type(to_fe_data_type( + attn_bias.value().scalar_type())))); } if (dropout_probability != 0.0f) { auto seed = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -940,27 +1024,24 @@ std::unique_ptr build_graph_backward( .set_name("Seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutseed.dtype() == kInt - ? 
fe::DataType_t::INT32 - : fe::DataType_t::INT64)); + .set_data_type(fe::DataType_t::INT64)); auto offset = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(OFFSET) .set_name("Offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_data_type( - dropoutoffset.dtype() == kInt - ? fe::DataType_t::INT32 - : fe::DataType_t::INT64)); + .set_data_type(fe::DataType_t::INT64)); sdpa_backward_options.set_dropout(dropout_probability, seed, offset); } auto O_ = mha_graph->tensor( fe::graph::Tensor_attributes().set_uid(O).set_name("O")); + // Infer the standard contiguous {b,h,s_q,1} Stats layout (matches + // at::empty allocation). auto Stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(LSE) .set_name("Stats") - .set_stride(softmaxstats.strides().vec()) + .set_dim({b, h, s_q, 1}) + .set_stride({h * s_q, s_q, 1, 1}) .set_data_type(fe::DataType_t::FLOAT)); auto Do = mha_graph->tensor( fe::graph::Tensor_attributes().set_uid(DO).set_name("DO")); @@ -969,7 +1050,12 @@ std::unique_ptr build_graph_backward( Dq->set_uid(DQ).set_output(true); Dk->set_uid(DK).set_output(true); Dv->set_uid(DV).set_output(true); - if (use_ragged_in_dense(q, k, v, o, attn_bias.has_value())) { + // O/dO sizes and strides are inferred (no `o`/`dO` tensor at build time): + // O/dO is {b, h, s_q, d_v} matching Q's layout. dQ/dK/dV share Q/K/V's + // layout (PyTorch's dispatch ensures matching strides). 
+ std::vector o_sizes = {b, h, s_q, d_v}; + auto o_strides = compute_matching_strides(q.sizes(), q.strides(), o_sizes); + if (use_ragged) { auto RAG_Q_OFF_ = mha_graph->tensor(fe::graph::Tensor_attributes() .set_uid(RAG_Q_OFF) @@ -1005,6 +1091,7 @@ std::unique_ptr build_graph_backward( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::INT32)); +#ifndef USE_HIPDNN O_->set_ragged_offset(RAG_O_OFF_); Q_->set_ragged_offset(RAG_Q_OFF_); K_->set_ragged_offset(RAG_K_OFF_); @@ -1013,10 +1100,13 @@ std::unique_ptr build_graph_backward( Dk->set_ragged_offset(RAG_K_OFF_); Dv->set_ragged_offset(RAG_V_OFF_); Do->set_ragged_offset(RAG_O_OFF_); +#else + TORCH_CHECK(false, "ragged-in-dense SDPA is not supported on hipDNN"); +#endif auto qsizevec = q.sizes().vec(); auto ksizevec = k.sizes().vec(); auto vsizevec = v.sizes().vec(); - auto osizevec = o.sizes().vec(); + auto osizevec = o_sizes; qsizevec[2] = roundup_power2(qsizevec[2]); ksizevec[2] = roundup_power2(ksizevec[2]); vsizevec[2] = roundup_power2(vsizevec[2]); @@ -1041,28 +1131,32 @@ std::unique_ptr build_graph_backward( Do->set_dim(osizevec).set_stride( {INT_MAX, osizevec[3], osizevec[1] * osizevec[3], 1}); +#ifndef USE_HIPDNN Stats->set_ragged_offset(RAG_STATS_OFF_); - auto statssizevec = softmaxstats.sizes().vec(); +#endif + std::vector statssizevec = {b, h, s_q, 1}; statssizevec[2] = roundup_power2(statssizevec[2]); Stats->set_dim(statssizevec); } else { - O_->set_dim(o.sizes().vec()).set_stride(o.strides().vec()); Q_->set_dim(q.sizes().vec()).set_stride(q.strides().vec()); K_->set_dim(k.sizes().vec()).set_stride(k.strides().vec()); V_->set_dim(v.sizes().vec()).set_stride(v.strides().vec()); - Dq->set_dim(dQ.sizes().vec()).set_stride(dQ.strides().vec()); - Dk->set_dim(dK.sizes().vec()).set_stride(dK.strides().vec()); - Dv->set_dim(dV.sizes().vec()).set_stride(dV.strides().vec()); - Do->set_dim(dO.sizes().vec()).set_stride(dO.strides().vec()); - Stats->set_dim(softmaxstats.sizes().vec()); 
+ O_->set_dim(o_sizes).set_stride(o_strides); + Dq->set_dim(q.sizes().vec()).set_stride(q.strides().vec()); + Dk->set_dim(k.sizes().vec()).set_stride(k.strides().vec()); + Dv->set_dim(v.sizes().vec()).set_stride(v.strides().vec()); + Do->set_dim(o_sizes).set_stride(o_strides); + Stats->set_dim({b, h, s_q, 1}); } - AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); - AT_CUDNN_FRONTEND_CHECK( - mha_graph->create_execution_plans({fe::HeurMode_t::A})); - AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); - AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); + if (finalize) { + AT_CUDNN_FRONTEND_CHECK(mha_graph->validate()); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_operation_graph(handle)); + AT_CUDNN_FRONTEND_CHECK( + mha_graph->create_execution_plans({fe::HeurMode_t::A})); + AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); + AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); + } return mha_graph; } @@ -1093,6 +1187,9 @@ std::unique_ptr build_graph_backward_nestedtensor( const Tensor& dropoutseed, const Tensor& dropoutoffset, cudnnHandle_t& handle) { +#ifdef USE_HIPDNN + TORCH_CHECK(false, "nested-tensor SDPA backward is not supported on hipDNN"); +#else auto dtype = fe::DataType_t::HALF; if (q.scalar_type() == kBFloat16) { dtype = fe::DataType_t::BFLOAT16; @@ -1305,6 +1402,7 @@ std::unique_ptr build_graph_backward_nestedtensor( AT_CUDNN_FRONTEND_CHECK(mha_graph->check_support(handle)); AT_CUDNN_FRONTEND_CHECK(mha_graph->build_plans(handle)); return mha_graph; +#endif } void run_cudnn_SDP_fprop( @@ -1367,14 +1465,16 @@ void run_cudnn_SDP_fprop( } } - const auto dprops = at::cuda::getCurrentDeviceProperties(); auto _dropoutseed = dropoutseed; auto _dropoutoffset = dropoutoffset; +#ifndef USE_HIPDNN + const auto dprops = at::cuda::getCurrentDeviceProperties(); // cuDNN dropout bug requires these to be in int64 if (dprops->major == 10 && dprops->minor == 0) { _dropoutseed = 
dropoutseed.to(kLong); _dropoutoffset = dropoutoffset.to(kLong); } +#endif cudnnHandle_t handle = getCudnnHandle(); @@ -1395,9 +1495,12 @@ void run_cudnn_SDP_fprop( dropout_probability, is_causal, return_softmaxstats, - false); - auto [cache_it, not_found] = getMHAGraphCache_().try_emplace(key, nullptr); - if (not_found) { +#ifdef USE_HIPDNN + scaling_factor, +#endif + /*is_nested=*/false); + auto cache_it = getMHAGraphCache_().try_emplace(key, nullptr).first; + if (cache_it->second == nullptr) { cache_it->second = build_graph( b, h, @@ -1413,18 +1516,35 @@ void run_cudnn_SDP_fprop( k, v, attn_bias, - softmaxstats, - o, - _dropoutseed, - _dropoutoffset, - handle); + handle, + /*finalize=*/true, + use_ragged); } const fe::graph::Graph& mha_graph = *cache_it->second; + // Validate that the runtime tensor metadata still matches the + // cached graph. + check_tensor_matches_graph(mha_graph, Q, q); + check_tensor_matches_graph(mha_graph, K, k); + check_tensor_matches_graph(mha_graph, V, v); + check_tensor_matches_graph(mha_graph, O, o); + if (return_softmaxstats) { + check_tensor_matches_graph(mha_graph, LSE, softmaxstats); + } + if (attn_bias.has_value()) { + check_tensor_matches_graph(mha_graph, BIAS, attn_bias.value()); + } + if (dropout_probability != 0.0f) { + check_tensor_matches_graph(mha_graph, SEED, _dropoutseed); + check_tensor_matches_graph(mha_graph, OFFSET, _dropoutoffset); + } std::unordered_map variant_pack = { {Q, q.mutable_data_ptr()}, {K, k.mutable_data_ptr()}, {V, v.mutable_data_ptr()}, +#ifndef USE_HIPDNN + // hipDNN bakes the scale into the graph; no SCALE tensor. 
{SCALE, &scaling_factor}, +#endif {O, o.mutable_data_ptr()}}; if (return_softmaxstats) { variant_pack[LSE] = softmaxstats.mutable_data_ptr(); @@ -1447,7 +1567,12 @@ void run_cudnn_SDP_fprop( variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr(); } } +#if defined(USE_HIPDNN) || CUDNN_FRONTEND_VERSION >= 10700 + int64_t workspace_size = 0; + AT_CUDNN_FRONTEND_CHECK(mha_graph.get_workspace_size(workspace_size)); +#else auto workspace_size = mha_graph.get_workspace_size(); +#endif auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( @@ -1477,6 +1602,9 @@ void run_cudnn_SDP_fprop_nestedtensor( Tensor& o, Tensor& dropoutseed, Tensor& dropoutoffset) { +#ifdef USE_HIPDNN + TORCH_CHECK(false, "nested-tensor SDPA is not supported on hipDNN"); +#else cudnnHandle_t handle = getCudnnHandle(); // do nothing if we got 0-element tensors if (!q.numel() || !k.numel() || !v.numel()) { @@ -1574,6 +1702,7 @@ void run_cudnn_SDP_fprop_nestedtensor( c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK( mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good()); +#endif } void run_cudnn_SDP_bprop( @@ -1606,14 +1735,16 @@ void run_cudnn_SDP_bprop( Tensor seqlen_q, seqlen_kv; Tensor rag_off_q, rag_off_k, rag_off_v, rag_off_o, rag_off_lse; - auto dprops = at::cuda::getCurrentDeviceProperties(); auto _dropoutseed = dropoutseed; auto _dropoutoffset = dropoutoffset; +#ifndef USE_HIPDNN + auto dprops = at::cuda::getCurrentDeviceProperties(); // cuDNN dropout bug requires these to be in int64 if (dprops->major == 10 && dprops->minor == 0) { _dropoutseed = dropoutseed.to(kLong); _dropoutoffset = dropoutoffset.to(kLong); } +#endif Tensor dO_ = dO; // cuDNN < 9.5.1 assumes gradOutput has same strides as Output @@ -1669,10 +1800,13 @@ void run_cudnn_SDP_bprop( dropout_probability, is_causal, true, - false); - auto [cache_it, not_found] = - getMHAGraphBackwardCache_().try_emplace(key, nullptr); - if (not_found) { 
+#ifdef USE_HIPDNN + scaling_factor, +#endif + /*is_nested=*/false); + auto cache_it = + getMHAGraphBackwardCache_().try_emplace(key, nullptr).first; + if (cache_it->second == nullptr) { cache_it->second = build_graph_backward( b, h, @@ -1687,17 +1821,27 @@ void run_cudnn_SDP_bprop( k, v, attn_bias, - o, - dO_, - softmaxstats, - dQ, - dK, - dV, - _dropoutseed, - _dropoutoffset, - handle); + handle, + /*finalize=*/true, + use_ragged_in_dense(q, k, v, o, attn_bias.has_value())); } const fe::graph::Graph& mha_graph = *cache_it->second; + check_tensor_matches_graph(mha_graph, Q, q); + check_tensor_matches_graph(mha_graph, K, k); + check_tensor_matches_graph(mha_graph, V, v); + check_tensor_matches_graph(mha_graph, O, o); + check_tensor_matches_graph(mha_graph, DO, dO_); + check_tensor_matches_graph(mha_graph, LSE, softmaxstats); + check_tensor_matches_graph(mha_graph, DQ, dQ); + check_tensor_matches_graph(mha_graph, DK, dK); + check_tensor_matches_graph(mha_graph, DV, dV); + if (attn_bias.has_value()) { + check_tensor_matches_graph(mha_graph, BIAS, attn_bias.value()); + } + if (dropout_probability != 0.0f) { + check_tensor_matches_graph(mha_graph, SEED, _dropoutseed); + check_tensor_matches_graph(mha_graph, OFFSET, _dropoutoffset); + } std::unordered_map variant_pack = { // inputs @@ -1711,7 +1855,10 @@ void run_cudnn_SDP_bprop( {DQ, dQ.mutable_data_ptr()}, {DK, dK.mutable_data_ptr()}, {DV, dV.mutable_data_ptr()}, - {SCALE, &scaling_factor}}; +#ifndef USE_HIPDNN + {SCALE, &scaling_factor} +#endif + }; if (dropout_probability != 0.0f) { variant_pack[SEED] = _dropoutseed.mutable_data_ptr(); variant_pack[OFFSET] = _dropoutoffset.mutable_data_ptr(); @@ -1729,7 +1876,12 @@ void run_cudnn_SDP_bprop( variant_pack[RAG_LSE_OFF] = rag_off_lse.mutable_data_ptr(); } +#if defined(USE_HIPDNN) || CUDNN_FRONTEND_VERSION >= 10700 + int64_t workspace_size = 0; + AT_CUDNN_FRONTEND_CHECK(mha_graph.get_workspace_size(workspace_size)); +#else auto workspace_size = 
mha_graph.get_workspace_size(); +#endif auto workspace_ptr = c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size); TORCH_CHECK(!workspace_size || workspace_ptr.get()); @@ -1763,6 +1915,9 @@ void run_cudnn_SDP_bprop_nestedtensor( Tensor& dV, const Tensor& dropoutseed, const Tensor& dropoutoffset) { +#ifdef USE_HIPDNN + TORCH_CHECK(false, "nested-tensor SDPA is not supported on hipDNN"); +#else // do nothing if we got 0-element tensors if (!q.numel() || !k.numel() || !v.numel() || !o.numel() || !dO.numel() || !softmaxstats.numel()) { @@ -1879,6 +2034,89 @@ void run_cudnn_SDP_bprop_nestedtensor( TORCH_CHECK(!workspace_size || workspace_ptr.get()); TORCH_CHECK( mha_graph.execute(handle, variant_pack, workspace_ptr.get()).is_good()); +#endif +} + +using MHASupportCache = std::unordered_map< + MHACacheKeyWrapper, + bool, + ParamsWrapperHash>; +static MHASupportCache& getMHASupportCache_() { + thread_local MHASupportCache instance; + return instance; +} + + +bool check_cudnn_sdpa_support(sdp::sdp_params const& params, bool debug) { +#ifndef USE_HIPDNN + return true; +#else + const Tensor& q = params.query; + const Tensor& k = params.key; + const Tensor& v = params.value; + // Concrete sizes are required to query the backend. 
+ if (q.unsafeGetTensorImpl()->has_symbolic_sizes_strides() || + k.unsafeGetTensorImpl()->has_symbolic_sizes_strides() || + v.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) { + if (debug) { + TORCH_WARN("hipDNN SDPA: static shapes are required"); + } + return false; + } + const int64_t b = q.size(0); + const int64_t h = q.size(1); + const int64_t s_q = q.size(2); + const int64_t s_kv = k.size(2); + const int64_t d_qk = q.size(3); + const int64_t d_v = v.size(3); + const float scaling_factor = + sdp::calculate_scale(q, params.scale).expect_float(); + const bool return_softmaxstats = sdp::input_requires_grad(params); + + // Mirror attention.cpp's bias rank handling so the cache key matches what + // run_cudnn_SDP_fprop sees at dispatch time. + std::optional bias = params.attn_mask; + if (bias.has_value()) { + const auto rank = bias.value().dim(); + TORCH_CHECK( + rank == 2 || rank == 3 || rank == 4, + "hipDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", + rank, + "D"); + const int64_t h_bias = rank == 4 ? bias.value().size(1) : 1; + bias = bias.value().expand({b, h_bias, s_q, s_kv}); + } + + cudnnHandle_t handle = getCudnnHandle(); + + MHACacheKeyWrapper key( + b, h, s_q, s_kv, d_qk, d_v, q, k, v, bias, params.dropout, + params.is_causal, return_softmaxstats, scaling_factor, + /*is_nested=*/false); + auto [it, not_probed] = getMHASupportCache_().try_emplace(key, false); + if (not_probed) { + // NOTE: We assume use_ragged=false here as hipDNN doesn't support ragged inputs currently. + auto fwd_graph = build_graph( + b, h, s_q, s_kv, d_qk, d_v, scaling_factor, return_softmaxstats, + params.is_causal, params.dropout, q, k, v, bias, handle, + /*finalize=*/false, /*use_ragged=*/false); + it->second = fwd_graph->is_supported_ext(handle).is_good(); + // Backward support probe. Only needed when grad is required.
+ if (it->second && return_softmaxstats) { + auto bwd_graph = build_graph_backward( + b, h, s_q, s_kv, d_qk, d_v, scaling_factor, params.is_causal, + params.dropout, q, k, v, bias, handle, + /*finalize=*/false, /*use_ragged=*/false); + it->second = bwd_graph->is_supported_ext(handle).is_good(); + } + } + if (!it->second && debug) { + TORCH_WARN( + "hipDNN SDPA: no engine available for the given input configuration. " + "Set HIPDNN_LOG_LEVEL=info for details."); + } + return it->second; +#endif // USE_HIPDNN } } // namespace at::native diff --git a/aten/src/ATen/native/cudnn/MHA.h b/aten/src/ATen/native/cudnn/MHA.h index 620abc1aa0a8e..2c853d54a1eff 100644 --- a/aten/src/ATen/native/cudnn/MHA.h +++ b/aten/src/ATen/native/cudnn/MHA.h @@ -1,6 +1,8 @@ #pragma once #include +#include + namespace at::native { void run_cudnn_SDP_fprop( @@ -97,4 +99,9 @@ void run_cudnn_SDP_bprop_nestedtensor( const Tensor& dropoutseed, const Tensor& dropoutoffset); +// Query backend to determine if graph configuration is supported. +// Matches the constraint-function signature so it can drop into the +// can_use_cudnn_attention check chain. +bool check_cudnn_sdpa_support(sdp::sdp_params const& params, bool debug); + } // namespace at::native diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index 74e2838851880..2d07785d73888 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -751,9 +751,15 @@ Tensor scaled_dot_product_attention( } const auto query_device_type = query_.device().type(); const auto backend = static_cast(choice_int); - auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); switch (backend) { case SDPBackend::cudnn_attention: { + // Skip the bool attention mask conversion on ROCm, as hipDNN handles + // bool masks itself. 
+#if defined(USE_ROCM) + auto attn_mask = attn_mask_; +#else + auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); +#endif bool compute_logsumexp = should_compute_logsumexp(query_, key, value); auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention( query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, false /*return_debug_mask*/, scale); @@ -774,10 +780,12 @@ Tensor scaled_dot_product_attention( return post_process_flash_output(std::get<0>(out_lse_softmax), og_size); } // For the CPU case we do not need to pad the last dim + auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); return std::get<0>(at::_scaled_dot_product_flash_attention_for_cpu( query_, key, value, dropout_p, is_causal, attn_mask, scale)); } case SDPBackend::efficient_attention: { + auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); bool compute_logsumexp = should_compute_logsumexp(query_, key, value); if (attn_mask.has_value()) { attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);; @@ -787,11 +795,13 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_and_lse); } case SDPBackend::overrideable: { + auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); auto out_lse_softmax = at::_scaled_dot_product_fused_attention_overrideable( query_, key, value, attn_mask, dropout_p, is_causal, false /*return_debug_mask*/, scale); return std::get<0>(out_lse_softmax); } case SDPBackend::math: { + auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); if (query_device_type == c10::kMPS && !(at::GradMode::is_enabled() && any_inputs_require_grad)) { return std::get<0>(at::_scaled_dot_product_attention_math_for_mps( diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 
79b5df3f302bb..d2e3b670dceaf 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -34,6 +34,8 @@ #define USE_ROCM_ATTENTION 0 #endif +#include + // Avoid potential compiler -Wall -Werror complains undefined macro #ifndef AOTRITON_VERSION_MINOR #define AOTRITON_VERSION_MINOR 0 @@ -76,6 +78,10 @@ bool priority_order_init_ = false; // TODO(eqy): more benchmarking to determine whether this should include sm86/89 // Needs to be kept in-sync with test_fused_chocie in test_transformers.py bool check_prefer_cudnn_attention() { + static const bool force_prefer = c10::utils::check_env("TORCH_CUDNN_SDPA_PREFERRED") == true; + if (force_prefer) { + return true; + } static const bool prefer_cudnn = c10::utils::check_env("TORCH_CUDNN_SDPA_DEPRIORITIZED") != true; if (!prefer_cudnn) { return false; @@ -111,6 +117,17 @@ std::array priority_order(sdp_params const& params) { static_cast(at::SDPBackend::math)}; at::globalContext().setSDPPriorityOrder(cudnn_order); } +#if USE_ROCM + else { + // On ROCm, default to hipDNN above math fallback. + const std::vector rocm_order = {static_cast(at::SDPBackend::flash_attention), + static_cast(at::SDPBackend::efficient_attention), + static_cast(at::SDPBackend::cudnn_attention), + static_cast(at::SDPBackend::math), + static_cast(at::SDPBackend::overrideable)}; + at::globalContext().setSDPPriorityOrder(rocm_order); + } +#endif } return at::globalContext().sDPPriorityOrder(); } @@ -735,15 +752,49 @@ bool check_cudnn_deterministic(const sdp_params& params, bool debug) { return true; } +#if defined(USE_HIPDNN) +bool check_hipdnn_enabled(sdp_params const& params, bool debug) { + if (!at::globalContext().userEnabledHipdnn()) { + if (debug) { + TORCH_WARN("hipDNN is not enabled. 
Set torch.backends.hipdnn.enabled = True"); + } + return false; + } + return true; +} + +bool check_no_nested_inputs_hipdnn(sdp_params const& params, bool debug) { + if (has_for_nested_inputs(params)) { + if (debug) { + TORCH_WARN("hipDNN SDPA does not support nested tensors."); + } + return false; + } + return true; +} + +bool check_dtypes_hipdnn(sdp_params const& params, bool debug) { + constexpr auto hipdnn_dtypes = c10::array_of( + at::kHalf, at::kBFloat16, at::kFloat, at::kDouble); + return check_tensor_dtype(params, hipdnn_dtypes, debug); +} +#endif // USE_HIPDNN + } // namespace bool can_use_cudnn_attention(const sdp_params& params, bool debug) { -#if defined(USE_ROCM) || !AT_CUDNN_ENABLED() || !defined(CUDNN_VERSION) +#if defined(USE_ROCM) && !defined(USE_HIPDNN) + if (debug) { + TORCH_WARN("Torch was not compiled with hipDNN."); + } + return false; +#elif !defined(USE_ROCM) && (!AT_CUDNN_ENABLED() || !defined(CUDNN_VERSION)) if (debug) { TORCH_WARN("Torch was not compiled with cuDNN attention."); } return false; -#endif +#else + #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use cuDNN Attention (< v9.0.0)"); @@ -759,31 +810,43 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { return false; } #endif - // Define gate functions that determine if a flash kernel can be ran - // Replace with std::to_array when we migrate to c++20 + + // Common constraint chain for cuDNN and hipDNN. Replace c10::array_of with + // std::to_array when we migrate to c++20. 
constexpr auto general_constraints = c10::array_of( check_runtime_disabled_cudnn, +#ifdef USE_HIPDNN + check_hipdnn_enabled, + check_no_nested_inputs_hipdnn, + check_dtypes_hipdnn, +#else check_for_nested_inputs, + check_cudnn_hardware_support, + check_cudnn_dropout, + check_dtypes_low_precision, +#endif check_all_tensors_on_device, check_tensor_shapes, check_cudnn_deterministic, - check_dtypes_low_precision, check_attn_mask_shape, - check_cudnn_hardware_support, - check_cudnn_dropout - ); + at::native::check_cudnn_sdpa_support + ); for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { return false; } } + + // Dense-only constraints. hipDNN already rejects nested inputs above. constexpr auto dense_constraints = c10::array_of( +#ifndef USE_HIPDNN check_nonzero_sequence_lengths_dense, check_last_dim_stride_equals_1_dense, - check_batch_size_and_num_heads_dense, - check_cudnn_tensor_shapes + check_cudnn_tensor_shapes, +#endif + check_batch_size_and_num_heads_dense ); if (has_only_dense_inputs(params)) { @@ -793,6 +856,8 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { } } } +#endif + return true; } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 171e5446f5ed6..2b5f67e2ab5fb 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1874,6 +1874,12 @@ if(USE_ROCM) # Since PyTorch files contain HIP headers, this is also needed to capture the includes. # ROCM_INCLUDE_DIRS is defined in LoadHIP.cmake and added as SYSTEM includes in cmake/Dependencies.cmake target_include_directories(torch_hip PRIVATE ${Caffe2_HIP_INCLUDE}) + # Shim headers (cuda_runtime_api.h → hip_runtime_api.h, cudnn_frontend.h → + # hipdnn_frontend, ATen/cudnn/Handle.h → ATen/hipdnn/Handle.h, etc.) so + # non-hipified files compile against cuda-named APIs on HIP builds without + # source-level rewriting. Use BEFORE so these shadow the real ATen headers. 
+ target_include_directories(torch_hip BEFORE PRIVATE + ${PROJECT_SOURCE_DIR}/aten/src/ATen/hip_compat/include) target_include_directories(torch_hip INTERFACE $) endif() diff --git a/test/test_transformers.py b/test/test_transformers.py index ced9b0133e11d..e833f8ae77b82 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3145,6 +3145,43 @@ def test_cudnn_attention_broken_166211(self): self.assertFalse(dk.isnan().any()) self.assertFalse(dv.isnan().any()) + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") + def test_cudnn_attention_bool_mask(self): + # Simple bool attn_mask test. + q = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16, device='cuda') + k = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16, device='cuda') + v = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16, device='cuda') + + attn_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool, device='cuda')) + + with sdpa_kernel(SDPBackend.MATH): + out_math = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask) + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + out_cudnn = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask) + + self.assertEqual(out_math, out_cudnn, atol=5e-3, rtol=3e-3) + + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") + def test_cudnn_attention_mismatched_mask_dtype(self): + # Float attn_mask with a dtype different from Q (bf16 Q, fp32 mask). 
+ q = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16, device='cuda') + k = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16, device='cuda') + v = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16, device='cuda') + + bool_mask = torch.tril(torch.ones(8, 8, dtype=torch.bool, device='cuda')) + attn_mask = torch.where(bool_mask, 0.0, float('-inf')).to(torch.float32) + + with sdpa_kernel(SDPBackend.MATH): + out_math = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask) + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + out_cudnn = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask) + + self.assertEqual(out_math, out_cudnn, atol=5e-3, rtol=3e-3) + @skipIfRocm @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_mask_broken_177842(self): @@ -3222,6 +3259,15 @@ def test_cudnn_attention_mask_broken_177842(self): ) self.assertEqual(attn_output_math, attn_output_cudnn, atol=5e-3, rtol=3e-3) + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + def test_cudnn_attention_runtime_disabled(self, device): + q = torch.empty(2, 8, 128, 64, dtype=torch.bfloat16, device=device) + params = torch.backends.cuda.SDPAParams(q, q, q, None, 0.0, False, False) + with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]): + self.assertTrue(torch.backends.cuda.can_use_cudnn_attention(params)) + with sdpa_kernel([SDPBackend.MATH]): + self.assertFalse(torch.backends.cuda.can_use_cudnn_attention(params)) + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]): From e0226d9dfe0aa7aab9a3e88081969879493cf1e2 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Thu, 16 Apr 2026 10:48:20 -0700 Subject: [PATCH 09/10] Update tests to reflect cudnn backend being 
available on ROCM. --- test/test_transformers.py | 44 ++++++++++++++++++++------ torch/testing/_internal/common_cuda.py | 4 ++- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/test/test_transformers.py b/test/test_transformers.py index e833f8ae77b82..891891f1e961b 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -53,6 +53,7 @@ PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, PLATFORM_SUPPORTS_CK_SDPA, + TEST_HIPDNN, tf32_on_and_off, tf32_enabled, ) @@ -62,6 +63,10 @@ if TEST_FAIRSEQ: import fairseq.models.transformer as fairseq_transformer +# TODO: drop once an env-var toggle replaces the Python-level enable. +if torch.backends.hipdnn.is_available(): + torch.backends.hipdnn.set_flags(True) + SdpaShape = namedtuple('Sdpa_Shape', ['batch', 'num_heads', 'seq_len', 'head_dim']) Tolerances = namedtuple('Tolerances', ['atol', 'rtol']) @@ -1694,7 +1699,10 @@ def test_invalid_last_dim_stride(self, device, kernel: SDPBackend): size = SdpaShape(2, 2, 8, 8) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) q.as_strided_(size, [2, 2, 2, 2]) - with self.assertWarnsRegex(UserWarning, "All fused kernels require the last dimension of the input to have stride 1."): + expected_warning = "All fused kernels require the last dimension of the input to have stride 1." + if kernel == SDPBackend.CUDNN_ATTENTION and TEST_HIPDNN: + expected_warning = "hipDNN SDPA: no engine available for the given input configuration." 
+ with self.assertWarnsRegex(UserWarning, expected_warning): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) @@ -1768,8 +1776,13 @@ def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend): size = SdpaShape(2, 2, 3, 16) make_tensor = partial(torch.rand, device=device, dtype=torch.float64) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) - self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( - q, k, v, None, 0.0, False)) + run_fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False) + if kernel == SDPBackend.CUDNN_ATTENTION and TEST_WITH_ROCM: + # hipDNN accepts f64 inputs + run_fn() + else: + self.assertRaises(RuntimeError, run_fn) + @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") @@ -3043,7 +3056,10 @@ def test_cudnn_attention_broadcast_stride_zero(self, device, dtype): out_ref = F.scaled_dot_product_attention( q_bc.contiguous(), k_bc.contiguous(), v_bc.contiguous(), is_causal=True ) - torch.testing.assert_close(out, out_ref, atol=3e-3, rtol=3e-3) + atol, rtol = 3e-3, 3e-3 + if TEST_WITH_ROCM and dtype == torch.bfloat16: + atol, rtol = 2e-2, 2e-2 + torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") @unittest.skipIf( @@ -3112,6 +3128,7 @@ def test_cudnn_attention_seqlen1_dropout_heuristic(self): out = torch.nn.functional.scaled_dot_product_attention(q, q, q, dropout_p=0.5) out.backward(grad) + @skipIfRocm # This tests cudnn-specific rounding behaviour, skip on ROCM @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_low_dropout(self): q = torch.randn(2, 8, 128, 128, dtype=torch.half, device='cuda') @@ -3615,10 +3632,14 @@ def 
test_fused_sdp_choice(self, device, type: str): if "cuda" in str(device): device_capability = torch.cuda.get_device_capability() prefer_cudnn = "TORCH_CUDNN_SDPA_PREFERRED" not in os.environ or bool(os.environ["TORCH_CUDNN_SDPA_PREFERRED"]) - # cuDNN prioritization requires cuDNN > 9.15.0 (91500) per sdp_utils.cpp:83 - cudnn_version = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 0 - is_hopper_or_newer = device_capability and (device_capability[0] == 9 or device_capability[0] == 10) - prefer_cudnn = prefer_cudnn and is_hopper_or_newer and cudnn_version > 91500 + if TEST_WITH_ROCM: + # hipDNN is not prioritized + prefer_cudnn = False + else: + # cuDNN prioritization requires cuDNN > 9.15.0 (91500) per sdp_utils.cpp:83 + cudnn_version = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else 0 + is_hopper_or_newer = device_capability and (device_capability[0] == 9 or device_capability[0] == 10) + prefer_cudnn = prefer_cudnn and is_hopper_or_newer and cudnn_version > 91500 # cuDNN is enabled by default on SM 9.0/10.0 with cuDNN > 9.15.0 (per #169849) # For older cuDNN versions or other architectures, Flash Attention is preferred @@ -3696,8 +3717,11 @@ def compiled_func(order): times.append(t1 - t0) self.assertTrue(times[0] < times[1], "expected cuDNN SDPA to be faster than Math backend.") self.assertTrue(times[1] > times[2], "expected Eff Attn backend to faster than Math backend.") - self.assertTrue(times[3] < times[2], "expected Flash Attn backend to faster than Math backend.") - self.assertTrue(times[0] < times[2], "expected cuDNN Attn backend to faster than Eff Attn backend.") + if not TEST_WITH_ROCM: + # Skip some checks on ROCM. hipDNN is currently slower than Eff Attn, + # and Flash/Eff are roughly equal. 
+ self.assertTrue(times[3] < times[2], "expected Flash Attn backend to faster than Eff Attn backend.") + self.assertTrue(times[0] < times[2], "expected cuDNN Attn backend to faster than Eff Attn backend.") reset_order = torch._C._get_sdp_priority_order() self.assertEqual(default_order, reset_order, "expected SDPA context manager to reset priority order.") diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index e7116f4eccc4e..dba97348f92b9 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -101,7 +101,9 @@ def evaluate_platform_supports_efficient_attention(): return False def evaluate_platform_supports_cudnn_attention(): - return (not TEST_WITH_ROCM) and SM80OrLater and (TEST_CUDNN_VERSION >= 90000) + if TEST_WITH_ROCM: + return TEST_HIPDNN + return SM80OrLater and (TEST_CUDNN_VERSION >= 90000) def evaluate_platform_supports_green_context(): if IS_WINDOWS: From 77cc7f93bb76bf91e080646aa085272c7516ae69 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Wed, 6 May 2026 12:25:18 -0700 Subject: [PATCH 10/10] [ROCm] Add additional SDPA test coverage Adds `test_fused_attention_custom_scale` parameterized over PLATFORM_SPECIFIC_SDPA (flash, efficient, cudnn). Each fused backend runs SDPA with a non-default `scale=` argument and is compared against the math backend with the same scale. No existing PyTorch SDPA test exercises a non-default scale on a fused/cuDNN backend. 
Co-Authored-By: Claude Opus 4 (1M context) --- test/test_transformers.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/test_transformers.py b/test/test_transformers.py index 891891f1e961b..132eab9f08caf 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3285,6 +3285,24 @@ def test_cudnn_attention_runtime_disabled(self, device): with sdpa_kernel([SDPBackend.MATH]): self.assertFalse(torch.backends.cuda.can_use_cudnn_attention(params)) + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") + @parametrize("kernel", PLATFORM_SPECIFIC_SDPA) + def test_fused_attention_custom_scale(self, device, kernel: SDPBackend): + dtype = torch.bfloat16 + make_tensor = partial(torch.rand, device=device, dtype=dtype) + size = SdpaShape(2, 4, 128, 64) + query, key, value = make_tensor(size), make_tensor(size), make_tensor(size) + # Custom scale that differs from the default 1/sqrt(head_dim). + scale = 0.05 + + with sdpa_kernel(backends=[SDPBackend.MATH]): + math_ref = F.scaled_dot_product_attention(query, key, value, scale=scale) + + with sdpa_kernel(backends=[kernel]): + actual = F.scaled_dot_product_attention(query, key, value, scale=scale) + + self.assertEqual(actual, math_ref, atol=2e-2, rtol=2e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]):